Langchain-Chatchat/server/knowledge_base/kb_service/milvus_kb_service.py

66 lines
1.9 KiB
Python

from pymilvus import (
connections,
utility,
FieldSchema,
CollectionSchema,
DataType,
Collection,
)
from server.knowledge_base.kb_service.base import KBService
def get_collection(milvus_name):
    """Open a pymilvus ``Collection`` handle by name.

    Args:
        milvus_name: Name of the Milvus collection to open. No schema is passed,
            so this presumably expects the collection to already exist on the
            server (TODO confirm against the pymilvus version in use).

    Returns:
        The ``Collection`` object for *milvus_name*.
    """
    handle = Collection(milvus_name)
    return handle
def search(milvus_name, content, limit=3, output_fields=None, nprobe=10):
    """Run an L2 vector similarity search against a Milvus collection.

    Args:
        milvus_name: Name of the collection to search (opened via ``get_collection``).
        content: Query vectors, passed as the ``data`` argument of
            ``Collection.search``. Presumably a list of float lists whose length
            matches the collection's ``embeddings`` dim; confirm with callers.
        limit: Maximum number of hits returned per query vector.
        output_fields: Scalar fields to return alongside each hit. Defaults to
            ``["content"]``, the text field of the schema used in this module.
        nprobe: Number of IVF clusters scanned per query (bounded by the index's
            ``nlist``).

    Returns:
        Whatever ``Collection.search`` returns for the query.
    """
    if output_fields is None:
        # Only request fields that exist in the schema; an unknown field name
        # (previously "random") makes Milvus reject the whole search.
        output_fields = ["content"]
    search_params = {
        # Must agree with the metric of the IVF_FLAT index built on "embeddings".
        "metric_type": "L2",
        "params": {"nprobe": nprobe},
    }
    collection = get_collection(milvus_name)
    return collection.search(
        content,
        "embeddings",
        search_params,
        limit=limit,
        output_fields=output_fields,
    )
class MilvusKBService(KBService):
    """Knowledge-base service that stores documents and embeddings in Milvus.

    NOTE(review): this class now subclasses ``KBService``. That name was imported
    but unused. Without it, ``super().__init__(...)`` resolved to
    ``object.__init__`` and raised ``TypeError``. Confirm that
    ``KBService.__init__`` accepts ``(knowledge_base_name, vector_store_type)``
    and that any abstract methods it declares are implemented here.
    """

    milvus_host: str
    milvus_port: int
    dim: int
    content_max_length: int

    def __init__(self, knowledge_base_name: str, vector_store_type: str, milvus_host="localhost",
                 milvus_port=19530, dim=8, content_max_length=65535):
        """Store connection settings and the vector/text field sizes.

        Args:
            knowledge_base_name: Forwarded to ``KBService.__init__``.
            vector_store_type: Forwarded to ``KBService.__init__``.
            milvus_host: Host of the Milvus server used by ``connect``.
            milvus_port: Port of the Milvus server used by ``connect``.
            dim: Dimension of the ``embeddings`` FLOAT_VECTOR field.
            content_max_length: ``max_length`` of the VARCHAR ``content`` field.
                Milvus accepts values from 1 to 65535.
        """
        super().__init__(knowledge_base_name, vector_store_type)
        self.milvus_host = milvus_host
        self.milvus_port = milvus_port
        self.dim = dim
        self.content_max_length = content_max_length

    def connect(self):
        """Register a pymilvus connection to the configured server under alias ``"default"``."""
        connections.connect("default", host=self.milvus_host, port=self.milvus_port)

    def create_collection(self, milvus_name):
        """Create (or open) a collection with a pk/content/embeddings schema, index it, and load it.

        Args:
            milvus_name: Name of the collection.

        Returns:
            The loaded ``Collection``.
        """
        fields = [
            FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=False),
            # Milvus rejects DataType.STRING for collection fields; text columns
            # must be VARCHAR with an explicit max_length.
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=self.content_max_length),
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.dim),
        ]
        schema = CollectionSchema(fields)
        collection = Collection(milvus_name, schema)
        index = {
            "index_type": "IVF_FLAT",
            # Searches in this module use the same "L2" metric.
            "metric_type": "L2",
            "params": {"nlist": 128},
        }
        collection.create_index("embeddings", index)
        # Searching requires the collection to be loaded into memory.
        collection.load()
        return collection

    def insert_collection(self, milvus_name, content=None):
        """Insert *content* into the collection *milvus_name*.

        Args:
            milvus_name: Name of an existing collection.
            content: Data accepted by ``Collection.insert``, e.g. column-ordered
                lists ``[pks, contents, embeddings]`` matching the schema from
                ``create_collection``. ``None`` or an empty container is a no-op.

        Returns:
            The result of ``Collection.insert``, or ``None`` when nothing was inserted.
        """
        # len() rather than truthiness so DataFrame-like inputs are not ambiguous.
        if content is None or len(content) == 0:
            return None
        return get_collection(milvus_name).insert(content)
if __name__ == '__main__':
    # Manual smoke test against a local Milvus instance.
    # NOTE(review): "test" / "milvus" are placeholder knowledge-base name and
    # vector-store type; adjust to whatever KBService expects.
    milvus_service = MilvusKBService("test", "milvus", milvus_host="192.168.50.128")
    milvus_service.connect()
    milvus_service.create_collection("test")
    # Column-ordered rows matching the schema: pk (INT64), content (text),
    # embeddings (float vectors of length dim).
    dataset = [
        [1, 2],
        ["hello", "world"],
        [[0.0] * milvus_service.dim, [1.0] * milvus_service.dim],
    ]
    milvus_service.insert_collection("test", dataset)