添加Milvus库 (#1011)
This commit is contained in:
parent
18d31f5116
commit
0f46185cfb
|
|
@ -280,10 +280,13 @@ BING_SUBSCRIPTION_KEY = ""
|
|||
|
||||
kbs_config = {
|
||||
"faiss": {
|
||||
|
||||
},
|
||||
"milvus": {
|
||||
"milvus_host": "192.168.50.128",
|
||||
"milvus_port": 19530
|
||||
"host": "127.0.0.1",
|
||||
"port": "19530",
|
||||
"user": "",
|
||||
"password": "",
|
||||
"secure": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
|
||||
|
|
@ -195,11 +195,9 @@ class KBService(ABC):
|
|||
|
||||
def __init__(self,
|
||||
knowledge_base_name: str,
|
||||
vector_store_type: str = "faiss",
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
):
|
||||
self.kb_name = knowledge_base_name
|
||||
self.vs_type = vector_store_type
|
||||
self.embed_model = embed_model
|
||||
self.kb_path = get_kb_path(self.kb_name)
|
||||
self.doc_path = get_doc_path(self.kb_name)
|
||||
|
|
@ -212,7 +210,7 @@ class KBService(ABC):
|
|||
if not os.path.exists(self.doc_path):
|
||||
os.makedirs(self.doc_path)
|
||||
self.do_create_kb()
|
||||
status = add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
|
||||
status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model)
|
||||
return status
|
||||
|
||||
def clear_vs(self):
|
||||
|
|
@ -225,7 +223,7 @@ class KBService(ABC):
|
|||
"""
|
||||
删除知识库
|
||||
"""
|
||||
self.do_remove_kb()
|
||||
self.do_drop_kb()
|
||||
status = delete_kb_from_db(self.kb_name)
|
||||
return status
|
||||
|
||||
|
|
@ -245,7 +243,7 @@ class KBService(ABC):
|
|||
"""
|
||||
if os.path.exists(kb_file.filepath):
|
||||
os.remove(kb_file.filepath)
|
||||
self.do_delete(kb_file)
|
||||
self.do_delete_doc(kb_file)
|
||||
status = delete_file_from_db(kb_file)
|
||||
return status
|
||||
|
||||
|
|
@ -293,7 +291,7 @@ class KBService(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_remove_kb(self):
|
||||
def do_drop_kb(self):
|
||||
"""
|
||||
删除知识库子类实自己逻辑
|
||||
"""
|
||||
|
|
@ -320,7 +318,7 @@ class KBService(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_delete(self,
|
||||
def do_delete_doc(self,
|
||||
kb_file: KnowledgeFile):
|
||||
"""
|
||||
从知识库删除文档子类实自己逻辑
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class DefaultKBService(KBService):
|
|||
def do_init(self):
|
||||
pass
|
||||
|
||||
def do_remove_kbs(self):
|
||||
def do_drop_kbs(self):
|
||||
pass
|
||||
|
||||
def do_search(self):
|
||||
|
|
@ -23,5 +23,5 @@ class DefaultKBService(KBService):
|
|||
def do_insert_one_knowledge(self):
|
||||
pass
|
||||
|
||||
def do_delete(self):
|
||||
def do_delete_doc(self):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class FaissKBService(KBService):
|
|||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
|
||||
def do_remove_kb(self):
|
||||
def do_drop_kb(self):
|
||||
shutil.rmtree(self.kb_path)
|
||||
|
||||
def do_search(self,
|
||||
|
|
@ -105,7 +105,7 @@ class FaissKBService(KBService):
|
|||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
||||
def do_delete(self,
|
||||
def do_delete_doc(self,
|
||||
kb_file: KnowledgeFile):
|
||||
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
|
||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||
|
|
@ -113,7 +113,6 @@ class FaissKBService(KBService):
|
|||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||
if len(ids) == 0:
|
||||
return None
|
||||
print(len(ids))
|
||||
vector_store = delete_doc_from_faiss(vector_store, ids)
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
|
|
|||
|
|
@ -1,65 +1,79 @@
|
|||
from pymilvus import (
|
||||
connections,
|
||||
utility,
|
||||
FieldSchema,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
Collection,
|
||||
)
|
||||
from typing import List
|
||||
|
||||
from server.knowledge_base.kb_service.base import KBService
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Milvus
|
||||
|
||||
from configs.model_config import EMBEDDING_DEVICE, kbs_config
|
||||
from server.knowledge_base import KnowledgeFile
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings, add_doc_to_db
|
||||
|
||||
|
||||
def get_collection(milvus_name):
|
||||
return Collection(milvus_name)
|
||||
class MilvusKBService(KBService):
|
||||
milvus: Milvus
|
||||
|
||||
@staticmethod
|
||||
def get_collection(milvus_name):
|
||||
from pymilvus import Collection
|
||||
return Collection(milvus_name)
|
||||
|
||||
def search(milvus_name, content, limit=3):
|
||||
search_params = {
|
||||
"metric_type": "L2",
|
||||
"params": {"nprobe": 10},
|
||||
}
|
||||
c = get_collection(milvus_name)
|
||||
return c.search(content, "embeddings", search_params, limit=limit, output_fields=["random"])
|
||||
|
||||
|
||||
class MilvusKBService():
|
||||
milvus_host: str
|
||||
milvus_port: int
|
||||
dim: int
|
||||
|
||||
def __init__(self, knowledge_base_name: str, vector_store_type: str, milvus_host="localhost", milvus_port=19530,
|
||||
dim=8):
|
||||
|
||||
super().__init__(knowledge_base_name, vector_store_type)
|
||||
self.milvus_host = milvus_host
|
||||
self.milvus_port = milvus_port
|
||||
self.dim = dim
|
||||
|
||||
def connect(self):
|
||||
connections.connect("default", host=self.milvus_host, port=self.milvus_port)
|
||||
|
||||
def create_collection(self, milvus_name):
|
||||
fields = [
|
||||
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=False),
|
||||
FieldSchema(name="content", dtype=DataType.STRING),
|
||||
FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.dim)
|
||||
]
|
||||
schema = CollectionSchema(fields)
|
||||
collection = Collection(milvus_name, schema)
|
||||
index = {
|
||||
"index_type": "IVF_FLAT",
|
||||
@staticmethod
|
||||
def search(milvus_name, content, limit=3):
|
||||
search_params = {
|
||||
"metric_type": "L2",
|
||||
"params": {"nlist": 128},
|
||||
"params": {"nprobe": 10},
|
||||
}
|
||||
collection.create_index("embeddings", index)
|
||||
collection.load()
|
||||
return collection
|
||||
c = MilvusKBService.get_collection(milvus_name)
|
||||
return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])
|
||||
|
||||
def insert_collection(self, milvus_name, content=[]):
|
||||
get_collection(milvus_name).insert(dataset)
|
||||
def do_create_kb(self):
|
||||
pass
|
||||
|
||||
def vs_type(self) -> str:
|
||||
return SupportedVSType.MILVUS
|
||||
|
||||
def _load_milvus(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
|
||||
_embeddings = embeddings
|
||||
if _embeddings is None:
|
||||
_embeddings = load_embeddings(self.embed_model, embedding_device)
|
||||
self.milvus = Milvus(embedding_function=_embeddings,
|
||||
collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
|
||||
|
||||
def do_init(self):
|
||||
self._load_milvus()
|
||||
|
||||
def do_drop_kb(self):
|
||||
self.milvus.col.drop()
|
||||
|
||||
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
|
||||
self._load_milvus(embeddings=embeddings)
|
||||
return self.milvus.similarity_search(query, top_k)
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile):
|
||||
"""
|
||||
向知识库添加文件
|
||||
"""
|
||||
docs = kb_file.file2text()
|
||||
self.milvus.add_documents(docs)
|
||||
status = add_doc_to_db(kb_file)
|
||||
return status
|
||||
|
||||
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
|
||||
pass
|
||||
|
||||
def do_delete_doc(self, kb_file: KnowledgeFile):
|
||||
filepath = kb_file.filepath.replace('\\', '\\\\')
|
||||
delete_list = [item.get("pk") for item in
|
||||
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
|
||||
self.milvus.col.delete(expr=f'pk in {delete_list}')
|
||||
|
||||
def do_clear_vs(self):
|
||||
self.milvus.col.drop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
milvusService = MilvusService(milvus_host='192.168.50.128')
|
||||
milvusService.insert_collection(test,dataset)
|
||||
milvusService = MilvusKBService("test")
|
||||
# milvusService.add_doc(KnowledgeFile("test.pdf", "test"))
|
||||
# milvusService.delete_doc(KnowledgeFile("test.pdf", "test"))
|
||||
milvusService.do_drop_kb()
|
||||
print(milvusService.search_docs("测试"))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
|
||||
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
|
||||
|
||||
class KBServiceFactory:
|
||||
|
|
@ -11,6 +12,9 @@ class KBServiceFactory:
|
|||
if SupportedVSType.FAISS == vector_store_type:
|
||||
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||
return FaissKBService(kb_name)
|
||||
elif SupportedVSType.MILVUS == vector_store_type:
|
||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
return MilvusKBService(kb_name)
|
||||
elif SupportedVSType.DEFAULT == vector_store_type:
|
||||
return DefaultKBService(kb_name)
|
||||
|
||||
|
|
@ -28,9 +32,9 @@ class KBServiceFactory:
|
|||
if __name__ == '__main__':
|
||||
KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS)
|
||||
init_db()
|
||||
KBService.create_kbs()
|
||||
KBService.create_kb()
|
||||
KBService = KBServiceFactory.get_default()
|
||||
print(KBService.list_kbs())
|
||||
KBService = KBServiceFactory.get_service_by_name("test")
|
||||
print(KBService.list_docs())
|
||||
KBService.drop_kbs()
|
||||
KBService.drop_kb()
|
||||
|
|
|
|||
Loading…
Reference in New Issue