diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index ce14d3c..c97f8cc 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -50,10 +50,9 @@ class KBService(ABC): def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings: return load_embeddings(self.embed_model, embed_device) - def save_vector_store(self, vector_store=None): + def save_vector_store(self): ''' - 保存向量库,仅支持FAISS。对于其它向量库该函数不做任何操作。 - 减少FAISS向量库操作时的类型判断。 + 保存向量库:FAISS保存到磁盘,milvus保存到数据库。PGVector暂未支持 ''' pass diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 444765f..f45a8d9 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -22,6 +22,10 @@ class MilvusKBService(KBService): from pymilvus import Collection return Collection(milvus_name) + def save_vector_store(self): + if self.milvus.col: + self.milvus.col.flush() + def get_doc_by_id(self, id: str) -> Optional[Document]: if self.milvus.col: data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"]) @@ -56,6 +60,7 @@ class MilvusKBService(KBService): def do_drop_kb(self): if self.milvus.col: + self.milvus.col.release() self.milvus.col.drop() def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): @@ -63,6 +68,11 @@ class MilvusKBService(KBService): return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k)) def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: + # TODO: workaround for bug #10492 in langchain==0.0.286 + for doc in docs: + for field in self.milvus.fields: + doc.metadata.setdefault(field, "") + ids = self.milvus.add_documents(docs) doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] return doc_infos @@ -76,7 +86,8 @@ class MilvusKBService(KBService): def do_clear_vs(self): if self.milvus.col: - self.milvus.col.drop() + self.do_drop_kb() + self.do_init() if __name__ == '__main__': diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index e6381fa..fa832ab 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -26,6 +26,7 @@ class PGKBService(KBService): collection_name=self.kb_name, distance_strategy=DistanceStrategy.EUCLIDEAN, connection_string=kbs_config.get("pg").get("connection_uri")) + def get_doc_by_id(self, id: str) -> Optional[Document]: with self.pg_vector.connect() as connect: stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id") @@ -77,6 +78,7 @@ class PGKBService(KBService): def do_clear_vs(self): self.pg_vector.delete_collection() + self.pg_vector.create_collection() if __name__ == '__main__': diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 118ae37..893c37d 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -68,9 +68,7 @@ def folder2db( kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True) else: print(result) - - if kb.vs_type() == SupportedVSType.FAISS: - kb.save_vector_store() + kb.save_vector_store() elif mode == "fill_info_only": files = list_files_from_folder(kb_name) kb_files = file_to_kbfile(kb_name, files) @@ -84,9 +82,7 @@ def folder2db( for kb_file in kb_files: kb.update_doc(kb_file, not_refresh_vs_cache=True) - - if kb.vs_type() == SupportedVSType.FAISS: - kb.save_vector_store() + kb.save_vector_store() elif mode == "increament": db_files = kb.list_files() folder_files = list_files_from_folder(kb_name) @@ -101,9 +97,7 @@ def folder2db( kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True) else: print(result) - - if kb.vs_type() == SupportedVSType.FAISS: - kb.save_vector_store() + kb.save_vector_store() else: print(f"unspported migrate mode: {mode}") @@ -133,8 +127,7 @@ def prune_db_files(kb_name: str): kb_files = file_to_kbfile(kb_name, files) for kb_file in kb_files: kb.delete_doc(kb_file, not_refresh_vs_cache=True) - if kb.vs_type() == SupportedVSType.FAISS: - kb.save_vector_store() + kb.save_vector_store() return kb_files def prune_folder_files(kb_name: str):