修复milvus_kb_service中一些bug,添加文档后将数据同步到数据库 (#1452)
This commit is contained in:
parent
4aa14b859e
commit
efd6d4a251
|
|
@ -50,10 +50,9 @@ class KBService(ABC):
|
|||
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
    """Load the embeddings for ``self.embed_model`` on the given device.

    Args:
        embed_device: Device forwarded to ``load_embeddings``. The default
            is computed once, when this function is defined, by calling
            ``embedding_device()``.

    Returns:
        The object returned by ``load_embeddings``.
    """
    model_name = self.embed_model
    return load_embeddings(model_name, embed_device)
|
||||
|
||||
def save_vector_store(self, vector_store=None):
|
||||
def save_vector_store(self):
    """Persist the vector store.

    FAISS is saved to disk and Milvus is saved to the database; PGVector
    is not supported yet. This implementation itself does nothing.
    """
    pass
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,10 @@ class MilvusKBService(KBService):
|
|||
from pymilvus import Collection
|
||||
return Collection(milvus_name)
|
||||
|
||||
def save_vector_store(self):
    """Flush the Milvus collection when one is present; otherwise do nothing."""
    collection = self.milvus.col
    if not collection:
        return
    collection.flush()
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
if self.milvus.col:
|
||||
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
|
||||
|
|
@ -56,6 +60,7 @@ class MilvusKBService(KBService):
|
|||
|
||||
def do_drop_kb(self):
    """Release and then drop the Milvus collection when one is present."""
    collection = self.milvus.col
    if not collection:
        return
    # Order matters: the collection is released before it is dropped.
    collection.release()
    collection.drop()
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
|
||||
|
|
@ -63,6 +68,11 @@ class MilvusKBService(KBService):
|
|||
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
    """Add documents to the Milvus store and describe what was stored.

    Args:
        docs: Documents to add. Their ``metadata`` dicts are updated in
            place (see the workaround below).
        **kwargs: Accepted but not used here.

    Returns:
        One ``{"id": ..., "metadata": ...}`` dict per document, pairing each
        id returned by ``add_documents`` with the document's metadata.
    """
    # TODO: workaround for bug #10492 in langchain==0.0.286
    # Every field listed in self.milvus.fields gets an empty-string entry
    # in each document's metadata unless the key is already present.
    for document in docs:
        for field_name in self.milvus.fields:
            document.metadata.setdefault(field_name, "")

    stored_ids = self.milvus.add_documents(docs)

    summaries = []
    for doc_id, document in zip(stored_ids, docs):
        summaries.append({"id": doc_id, "metadata": document.metadata})
    return summaries
|
||||
|
|
@ -76,7 +86,8 @@ class MilvusKBService(KBService):
|
|||
|
||||
def do_clear_vs(self):
    """Clear the vector store by dropping the collection and re-initialising.

    The collection is removed through ``do_drop_kb()`` only. Dropping it
    directly first and then calling ``do_drop_kb()`` would act on a
    collection that has already been dropped.

    Does nothing when no collection is present.
    """
    if not self.milvus.col:
        return
    self.do_drop_kb()
    # NOTE(review): re-initialising only when a collection existed is
    # inferred; the original nesting of this call could not be confirmed.
    self.do_init()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class PGKBService(KBService):
|
|||
collection_name=self.kb_name,
|
||||
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
||||
connection_string=kbs_config.get("pg").get("connection_uri"))
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
with self.pg_vector.connect() as connect:
|
||||
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
|
||||
|
|
@ -77,6 +78,7 @@ class PGKBService(KBService):
|
|||
|
||||
def do_clear_vs(self):
    """Clear the store by deleting the PGVector collection and creating it again."""
    store = self.pg_vector
    store.delete_collection()
    store.create_collection()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -68,8 +68,6 @@ def folder2db(
|
|||
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
|
||||
else:
|
||||
print(result)
|
||||
|
||||
if kb.vs_type() == SupportedVSType.FAISS:
|
||||
kb.save_vector_store()
|
||||
elif mode == "fill_info_only":
|
||||
files = list_files_from_folder(kb_name)
|
||||
|
|
@ -84,8 +82,6 @@ def folder2db(
|
|||
|
||||
for kb_file in kb_files:
|
||||
kb.update_doc(kb_file, not_refresh_vs_cache=True)
|
||||
|
||||
if kb.vs_type() == SupportedVSType.FAISS:
|
||||
kb.save_vector_store()
|
||||
elif mode == "increament":
|
||||
db_files = kb.list_files()
|
||||
|
|
@ -101,8 +97,6 @@ def folder2db(
|
|||
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
|
||||
else:
|
||||
print(result)
|
||||
|
||||
if kb.vs_type() == SupportedVSType.FAISS:
|
||||
kb.save_vector_store()
|
||||
else:
|
||||
print(f"unspported migrate mode: {mode}")
|
||||
|
|
@ -133,7 +127,6 @@ def prune_db_files(kb_name: str):
|
|||
kb_files = file_to_kbfile(kb_name, files)
|
||||
for kb_file in kb_files:
|
||||
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
|
||||
if kb.vs_type() == SupportedVSType.FAISS:
|
||||
kb.save_vector_store()
|
||||
return kb_files
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue