修复milvus_kb_service中一些bug,添加文档后将数据同步到数据库 (#1452)
This commit is contained in:
parent
4aa14b859e
commit
efd6d4a251
|
|
@ -50,10 +50,9 @@ class KBService(ABC):
|
||||||
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
|
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
|
||||||
return load_embeddings(self.embed_model, embed_device)
|
return load_embeddings(self.embed_model, embed_device)
|
||||||
|
|
||||||
def save_vector_store(self, vector_store=None):
|
def save_vector_store(self):
|
||||||
'''
|
'''
|
||||||
保存向量库,仅支持FAISS。对于其它向量库该函数不做任何操作。
|
保存向量库:FAISS保存到磁盘,milvus保存到数据库。PGVector暂未支持
|
||||||
减少FAISS向量库操作时的类型判断。
|
|
||||||
'''
|
'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,10 @@ class MilvusKBService(KBService):
|
||||||
from pymilvus import Collection
|
from pymilvus import Collection
|
||||||
return Collection(milvus_name)
|
return Collection(milvus_name)
|
||||||
|
|
||||||
|
def save_vector_store(self):
|
||||||
|
if self.milvus.col:
|
||||||
|
self.milvus.col.flush()
|
||||||
|
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||||
if self.milvus.col:
|
if self.milvus.col:
|
||||||
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
|
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
|
||||||
|
|
@ -56,6 +60,7 @@ class MilvusKBService(KBService):
|
||||||
|
|
||||||
def do_drop_kb(self):
|
def do_drop_kb(self):
|
||||||
if self.milvus.col:
|
if self.milvus.col:
|
||||||
|
self.milvus.col.release()
|
||||||
self.milvus.col.drop()
|
self.milvus.col.drop()
|
||||||
|
|
||||||
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
|
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
|
||||||
|
|
@ -63,6 +68,11 @@ class MilvusKBService(KBService):
|
||||||
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
|
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
|
# TODO: workaround for bug #10492 in langchain==0.0.286
|
||||||
|
for doc in docs:
|
||||||
|
for field in self.milvus.fields:
|
||||||
|
doc.metadata.setdefault(field, "")
|
||||||
|
|
||||||
ids = self.milvus.add_documents(docs)
|
ids = self.milvus.add_documents(docs)
|
||||||
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
||||||
return doc_infos
|
return doc_infos
|
||||||
|
|
@ -76,7 +86,8 @@ class MilvusKBService(KBService):
|
||||||
|
|
||||||
def do_clear_vs(self):
|
def do_clear_vs(self):
|
||||||
if self.milvus.col:
|
if self.milvus.col:
|
||||||
self.milvus.col.drop()
|
self.do_drop_kb()
|
||||||
|
self.do_init()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ class PGKBService(KBService):
|
||||||
collection_name=self.kb_name,
|
collection_name=self.kb_name,
|
||||||
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
||||||
connection_string=kbs_config.get("pg").get("connection_uri"))
|
connection_string=kbs_config.get("pg").get("connection_uri"))
|
||||||
|
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||||
with self.pg_vector.connect() as connect:
|
with self.pg_vector.connect() as connect:
|
||||||
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
|
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
|
||||||
|
|
@ -77,6 +78,7 @@ class PGKBService(KBService):
|
||||||
|
|
||||||
def do_clear_vs(self):
|
def do_clear_vs(self):
|
||||||
self.pg_vector.delete_collection()
|
self.pg_vector.delete_collection()
|
||||||
|
self.pg_vector.create_collection()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
||||||
|
|
@ -68,9 +68,7 @@ def folder2db(
|
||||||
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
|
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
|
||||||
else:
|
else:
|
||||||
print(result)
|
print(result)
|
||||||
|
kb.save_vector_store()
|
||||||
if kb.vs_type() == SupportedVSType.FAISS:
|
|
||||||
kb.save_vector_store()
|
|
||||||
elif mode == "fill_info_only":
|
elif mode == "fill_info_only":
|
||||||
files = list_files_from_folder(kb_name)
|
files = list_files_from_folder(kb_name)
|
||||||
kb_files = file_to_kbfile(kb_name, files)
|
kb_files = file_to_kbfile(kb_name, files)
|
||||||
|
|
@ -84,9 +82,7 @@ def folder2db(
|
||||||
|
|
||||||
for kb_file in kb_files:
|
for kb_file in kb_files:
|
||||||
kb.update_doc(kb_file, not_refresh_vs_cache=True)
|
kb.update_doc(kb_file, not_refresh_vs_cache=True)
|
||||||
|
kb.save_vector_store()
|
||||||
if kb.vs_type() == SupportedVSType.FAISS:
|
|
||||||
kb.save_vector_store()
|
|
||||||
elif mode == "increament":
|
elif mode == "increament":
|
||||||
db_files = kb.list_files()
|
db_files = kb.list_files()
|
||||||
folder_files = list_files_from_folder(kb_name)
|
folder_files = list_files_from_folder(kb_name)
|
||||||
|
|
@ -101,9 +97,7 @@ def folder2db(
|
||||||
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
|
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
|
||||||
else:
|
else:
|
||||||
print(result)
|
print(result)
|
||||||
|
kb.save_vector_store()
|
||||||
if kb.vs_type() == SupportedVSType.FAISS:
|
|
||||||
kb.save_vector_store()
|
|
||||||
else:
|
else:
|
||||||
print(f"unspported migrate mode: {mode}")
|
print(f"unspported migrate mode: {mode}")
|
||||||
|
|
||||||
|
|
@ -133,8 +127,7 @@ def prune_db_files(kb_name: str):
|
||||||
kb_files = file_to_kbfile(kb_name, files)
|
kb_files = file_to_kbfile(kb_name, files)
|
||||||
for kb_file in kb_files:
|
for kb_file in kb_files:
|
||||||
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
|
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
|
||||||
if kb.vs_type() == SupportedVSType.FAISS:
|
kb.save_vector_store()
|
||||||
kb.save_vector_store()
|
|
||||||
return kb_files
|
return kb_files
|
||||||
|
|
||||||
def prune_folder_files(kb_name: str):
|
def prune_folder_files(kb_name: str):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue