diff --git a/configs/model_config.py.example b/configs/model_config.py.example index ecbabac..b73ccd9 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -280,10 +280,13 @@ BING_SUBSCRIPTION_KEY = "" kbs_config = { "faiss": { - }, "milvus": { - "milvus_host": "192.168.50.128", - "milvus_port": 19530 + "host": "127.0.0.1", + "port": "19530", + "user": "", + "password": "", + "secure": False, } -} \ No newline at end of file +} +DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") \ No newline at end of file diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 006ca81..c7fb24e 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -195,11 +195,9 @@ class KBService(ABC): def __init__(self, knowledge_base_name: str, - vector_store_type: str = "faiss", embed_model: str = EMBEDDING_MODEL, ): self.kb_name = knowledge_base_name - self.vs_type = vector_store_type self.embed_model = embed_model self.kb_path = get_kb_path(self.kb_name) self.doc_path = get_doc_path(self.kb_name) @@ -212,7 +210,7 @@ class KBService(ABC): if not os.path.exists(self.doc_path): os.makedirs(self.doc_path) self.do_create_kb() - status = add_kb_to_db(self.kb_name, self.vs_type, self.embed_model) + status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model) return status def clear_vs(self): @@ -225,7 +223,7 @@ class KBService(ABC): """ 删除知识库 """ - self.do_remove_kb() + self.do_drop_kb() status = delete_kb_from_db(self.kb_name) return status @@ -245,7 +243,7 @@ class KBService(ABC): """ if os.path.exists(kb_file.filepath): os.remove(kb_file.filepath) - self.do_delete(kb_file) + self.do_delete_doc(kb_file) status = delete_file_from_db(kb_file) return status @@ -293,7 +291,7 @@ class KBService(ABC): pass @abstractmethod - def do_remove_kb(self): + def do_drop_kb(self): """ 删除知识库子类实自己逻辑 """ @@ -320,7 +318,7 @@ class KBService(ABC): pass @abstractmethod - def do_delete(self, + def do_delete_doc(self, kb_file: KnowledgeFile): """ 从知识库删除文档子类实自己逻辑 diff --git a/server/knowledge_base/kb_service/default_kb_service.py b/server/knowledge_base/kb_service/default_kb_service.py index 3a6e0a5..fb31859 100644 --- a/server/knowledge_base/kb_service/default_kb_service.py +++ b/server/knowledge_base/kb_service/default_kb_service.py @@ -11,7 +11,7 @@ class DefaultKBService(KBService): def do_init(self): pass - def do_remove_kbs(self): + def do_drop_kbs(self): pass def do_search(self): @@ -23,5 +23,5 @@ class DefaultKBService(KBService): def do_insert_one_knowledge(self): pass - def do_delete(self): + def do_delete_doc(self): pass diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 3141754..64ed9ed 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -76,7 +76,7 @@ class FaissKBService(KBService): if not os.path.exists(self.vs_path): os.makedirs(self.vs_path) - def do_remove_kb(self): + def do_drop_kb(self): shutil.rmtree(self.kb_path) def do_search(self, @@ -105,7 +105,7 @@ class FaissKBService(KBService): vector_store.save_local(self.vs_path) refresh_vs_cache(self.kb_name) - def do_delete(self, + def do_delete_doc(self, kb_file: KnowledgeFile): embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE) if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): @@ -113,7 +113,6 @@ class FaissKBService(KBService): ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] if len(ids) == 0: return None - print(len(ids)) vector_store = delete_doc_from_faiss(vector_store, ids) vector_store.save_local(self.vs_path) refresh_vs_cache(self.kb_name) diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 5d43633..7dcceb4 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -1,65 +1,79 @@ -from pymilvus import ( - connections, - utility, - FieldSchema, - CollectionSchema, - DataType, - Collection, -) +from typing import List -from server.knowledge_base.kb_service.base import KBService +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from langchain.vectorstores import Milvus + +from configs.model_config import EMBEDDING_DEVICE, kbs_config +from server.knowledge_base import KnowledgeFile +from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings, add_doc_to_db -def get_collection(milvus_name): - return Collection(milvus_name) +class MilvusKBService(KBService): + milvus: Milvus + @staticmethod + def get_collection(milvus_name): + from pymilvus import Collection + return Collection(milvus_name) -def search(milvus_name, content, limit=3): - search_params = { - "metric_type": "L2", - "params": {"nprobe": 10}, - } - c = get_collection(milvus_name) - return c.search(content, "embeddings", search_params, limit=limit, output_fields=["random"]) - - -class MilvusKBService(): - milvus_host: str - milvus_port: int - dim: int - - def __init__(self, knowledge_base_name: str, vector_store_type: str, milvus_host="localhost", milvus_port=19530, - dim=8): - - super().__init__(knowledge_base_name, vector_store_type) - self.milvus_host = milvus_host - self.milvus_port = milvus_port - self.dim = dim - - def connect(self): - connections.connect("default", host=self.milvus_host, port=self.milvus_port) - - def create_collection(self, milvus_name): - fields = [ - FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=False), - FieldSchema(name="content", dtype=DataType.STRING), - FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.dim) - ] - schema = CollectionSchema(fields) - collection = Collection(milvus_name, schema) - index = { - "index_type": "IVF_FLAT", + @staticmethod + def search(milvus_name, content, limit=3): + search_params = { "metric_type": "L2", - "params": {"nlist": 128}, + "params": {"nprobe": 10}, } - collection.create_index("embeddings", index) - collection.load() - return collection + c = MilvusKBService.get_collection(milvus_name) + return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"]) - def insert_collection(self, milvus_name, content=[]): - get_collection(milvus_name).insert(dataset) + def do_create_kb(self): + pass + + def vs_type(self) -> str: + return SupportedVSType.MILVUS + + def _load_milvus(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None): + _embeddings = embeddings + if _embeddings is None: + _embeddings = load_embeddings(self.embed_model, embedding_device) + self.milvus = Milvus(embedding_function=_embeddings, + collection_name=self.kb_name, connection_args=kbs_config.get("milvus")) + + def do_init(self): + self._load_milvus() + + def do_drop_kb(self): + self.milvus.col.drop() + + def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]: + self._load_milvus(embeddings=embeddings) + return self.milvus.similarity_search(query, top_k) + + def add_doc(self, kb_file: KnowledgeFile): + """ + 向知识库添加文件 + """ + docs = kb_file.file2text() + self.milvus.add_documents(docs) + status = add_doc_to_db(kb_file) + return status + + def do_add_doc(self, docs: List[Document], embeddings: Embeddings): + pass + + def do_delete_doc(self, kb_file: KnowledgeFile): + filepath = kb_file.filepath.replace('\\', '\\\\') + delete_list = [item.get("pk") for item in + self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])] + self.milvus.col.delete(expr=f'pk in {delete_list}') + + def do_clear_vs(self): + self.milvus.col.drop() if __name__ == '__main__': - milvusService = MilvusService(milvus_host='192.168.50.128') - milvusService.insert_collection(test,dataset) + milvusService = MilvusKBService("test") + # milvusService.add_doc(KnowledgeFile("test.pdf", "test")) + # milvusService.delete_doc(KnowledgeFile("test.pdf", "test")) + milvusService.do_drop_kb() + print(milvusService.search_docs("测试")) diff --git a/server/knowledge_base/knowledge_base_factory.py b/server/knowledge_base/knowledge_base_factory.py index 793074e..7d61748 100644 --- a/server/knowledge_base/knowledge_base_factory.py +++ b/server/knowledge_base/knowledge_base_factory.py @@ -1,5 +1,6 @@ from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db from server.knowledge_base.kb_service.default_kb_service import DefaultKBService +from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService class KBServiceFactory: @@ -11,6 +12,9 @@ class KBServiceFactory: if SupportedVSType.FAISS == vector_store_type: from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService return FaissKBService(kb_name) + elif SupportedVSType.MILVUS == vector_store_type: + from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService + return MilvusKBService(kb_name) elif SupportedVSType.DEFAULT == vector_store_type: return DefaultKBService(kb_name) @@ -28,9 +32,9 @@ class KBServiceFactory: if __name__ == '__main__': KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS) init_db() - KBService.create_kbs() + KBService.create_kb() KBService = KBServiceFactory.get_default() print(KBService.list_kbs()) KBService = KBServiceFactory.get_service_by_name("test") print(KBService.list_docs()) - KBService.drop_kbs() + KBService.drop_kb()