修复milvus_kb_service中一些bug,添加文档后将数据同步到数据库 (#1452)

This commit is contained in:
liunux4odoo 2023-09-12 22:34:03 +08:00 committed by GitHub
parent 4aa14b859e
commit efd6d4a251
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 15 deletions

View File

@ -50,10 +50,9 @@ class KBService(ABC):
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
return load_embeddings(self.embed_model, embed_device)
def save_vector_store(self, vector_store=None):
def save_vector_store(self):
'''
保存向量库仅支持FAISS对于其它向量库该函数不做任何操作
减少FAISS向量库操作时的类型判断
保存向量库:FAISS保存到磁盘milvus保存到数据库PGVector暂未支持
'''
pass

View File

@ -22,6 +22,10 @@ class MilvusKBService(KBService):
from pymilvus import Collection
return Collection(milvus_name)
def save_vector_store(self):
if self.milvus.col:
self.milvus.col.flush()
def get_doc_by_id(self, id: str) -> Optional[Document]:
if self.milvus.col:
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
@ -56,6 +60,7 @@ class MilvusKBService(KBService):
def do_drop_kb(self):
if self.milvus.col:
self.milvus.col.release()
self.milvus.col.drop()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
@ -63,6 +68,11 @@ class MilvusKBService(KBService):
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
# TODO: workaround for bug #10492 in langchain==0.0.286
for doc in docs:
for field in self.milvus.fields:
doc.metadata.setdefault(field, "")
ids = self.milvus.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
return doc_infos
@ -76,7 +86,8 @@ class MilvusKBService(KBService):
def do_clear_vs(self):
if self.milvus.col:
self.milvus.col.drop()
self.do_drop_kb()
self.do_init()
if __name__ == '__main__':

View File

@ -26,6 +26,7 @@ class PGKBService(KBService):
collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri"))
def get_doc_by_id(self, id: str) -> Optional[Document]:
with self.pg_vector.connect() as connect:
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
@ -77,6 +78,7 @@ class PGKBService(KBService):
def do_clear_vs(self):
self.pg_vector.delete_collection()
self.pg_vector.create_collection()
if __name__ == '__main__':

View File

@ -68,8 +68,6 @@ def folder2db(
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
elif mode == "fill_info_only":
files = list_files_from_folder(kb_name)
@ -84,8 +82,6 @@ def folder2db(
for kb_file in kb_files:
kb.update_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
elif mode == "increament":
db_files = kb.list_files()
@ -101,8 +97,6 @@ def folder2db(
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
else:
print(f"unspported migrate mode: {mode}")
@ -133,7 +127,6 @@ def prune_db_files(kb_name: str):
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
return kb_files