开发者:XXKBService.get_doc_by_id 改为批量接口 get_doc_by_ids(传入 id 列表,返回 Document 列表),提高访问向量库效率。

This commit is contained in:
liunux4odoo 2023-11-16 11:09:40 +08:00
parent fbe214471b
commit 68a544ea33
5 changed files with 23 additions and 22 deletions

View File

@ -172,15 +172,15 @@ class KBService(ABC):
docs = self.do_search(query, top_k, score_threshold)
return docs
def get_doc_by_id(self, id: str) -> Optional[Document]:
    """Return the document stored under ``id``, or ``None`` if there is none.

    Delegates to the batch accessor ``get_doc_by_ids`` so that a service
    which only implements the batch method still answers single-id lookups
    instead of always yielding ``None``.

    Args:
        id: Identifier of the document. The name shadows the builtin but is
            kept because it is part of the existing call signature.

    Returns:
        The first entry of ``self.get_doc_by_ids([id])``, or ``None`` when
        that call returns an empty list.
    """
    docs = self.get_doc_by_ids([id])
    return docs[0] if docs else None
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
    """Look up several documents by id in a single call.

    This default implementation ignores ``ids`` and reports no documents.

    Args:
        ids: Identifiers of the documents to fetch.

    Returns:
        A new, empty list.
    """
    return list()
def list_docs(self, file_name: str = None, metadata: Optional[Dict] = None) -> List[Document]:
    '''
    Retrieve Documents by file_name or metadata.

    The matching records are read with ``list_docs_from_db`` for this
    knowledge base (``self.kb_name``); their ``"id"`` values are then
    resolved with one ``get_doc_by_ids`` call.

    Args:
        file_name: File name passed through to ``list_docs_from_db``.
        metadata: Metadata filter passed through to ``list_docs_from_db``.
            ``None`` (the default) is treated as an empty dict.

    Returns:
        Whatever ``get_doc_by_ids`` returns for the collected ids.
    '''
    # Build a fresh dict per call rather than sharing one mutable default.
    metadata = {} if metadata is None else metadata
    doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
    # A single batched lookup replaces the former one-lookup-per-id pass,
    # whose result was discarded anyway.
    return self.get_doc_by_ids([x["id"] for x in doc_infos])
@abstractmethod

View File

@ -32,9 +32,9 @@ class FaissKBService(KBService):
def save_vector_store(self):
    """Save the loaded vector store to ``self.vs_path``."""
    store = self.load_vector_store()
    store.save(self.vs_path)
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
    """Fetch the documents for ``ids`` from the loaded vector store's docstore.

    Args:
        ids: Docstore keys of the documents to fetch.

    Returns:
        One entry per element of ``ids``, in the same order. A key that is
        not present in the docstore yields ``None`` at its position (the
        lookup uses ``dict.get``), so the result stays aligned with ``ids``.
    """
    with self.load_vector_store().acquire() as vs:
        # ``_dict`` is an underscore-prefixed attribute of the docstore; it is
        # read directly so all ids are resolved while the store is acquired.
        stored = vs.docstore._dict
        return [stored.get(doc_id) for doc_id in ids]
def do_init(self):
self.vector_name = self.vector_name or self.embed_model

View File

@ -22,13 +22,14 @@ class MilvusKBService(KBService):
# if self.milvus.col:
# self.milvus.col.flush()
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
    """Fetch the documents whose primary key is in ``ids`` with one query.

    Args:
        ids: Primary keys as decimal strings.

    Returns:
        One ``Document`` per row returned by the query: the row's ``"text"``
        field becomes ``page_content`` and the remaining fields become
        ``metadata``. Empty when ``ids`` is empty or no collection is
        available (``self.milvus.col`` is falsy).

    Raises:
        ValueError: If an element of ``ids`` is not an integer literal.
    """
    result = []
    if not ids or not self.milvus.col:
        return result
    # The single-id query interpolated the id unquoted (``pk == 123``).
    # Formatting a list of str would render quoted string literals instead,
    # so convert to int first.
    # NOTE(review): assumes ``pk`` is an integer field, as the single-id
    # query did - confirm against the collection schema.
    pks = [int(_id) for _id in ids]
    data_list = self.milvus.col.query(expr=f'pk in {pks}', output_fields=["*"])
    for data in data_list:
        text = data.pop("text")
        result.append(Document(page_content=text, metadata=data))
    return result
@staticmethod
def search(milvus_name, content, limit=3):
@ -99,7 +100,7 @@ if __name__ == '__main__':
milvusService = MilvusKBService("test")
# milvusService.add_doc(KnowledgeFile("README.md", "test"))
print(milvusService.get_doc_by_id("444022434274215486"))
print(milvusService.get_doc_by_ids(["444022434274215486"]))
# milvusService.delete_doc(KnowledgeFile("README.md", "test"))
# milvusService.do_drop_kb()
# print(milvusService.search_docs("如何启动api服务"))

View File

@ -22,13 +22,12 @@ class PGKBService(KBService):
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri"))
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
    """Fetch the rows of ``langchain_pg_embedding`` matching ``ids`` in one query.

    Args:
        ids: Values to match against the ``collection_id`` column.

    Returns:
        One ``Document`` per matching row, built from the ``document`` and
        ``cmetadata`` columns. Empty when ``ids`` is empty.
    """
    if not ids:
        return []

    # Imported here so the block does not depend on a new top-of-file import.
    from sqlalchemy import bindparam

    # An expanding bind parameter renders the list as ``IN (v1, v2, ...)``;
    # a plain ``:ids`` parameter cannot carry a Python list for ``IN``.
    # NOTE(review): the filter column is ``collection_id`` - confirm the ids
    # handed in are values of that column rather than per-row document ids.
    stmt = text(
        "SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids"
    ).bindparams(bindparam("ids", expanding=True))
    with self.pg_vector.connect() as connect:
        rows = connect.execute(stmt, parameters={'ids': list(ids)}).fetchall()
    return [Document(page_content=row[0], metadata=row[1]) for row in rows]
def do_init(self):
self._load_pg_vector()
@ -88,5 +87,5 @@ if __name__ == '__main__':
# pGKBService.add_doc(KnowledgeFile("README.md", "test"))
# pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
# pGKBService.drop_kb()
print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772"))
print(pGKBService.get_doc_by_ids(["f1e51390-3029-4a19-90dc-7118aaa25772"]))
# print(pGKBService.search_docs("如何启动api服务"))

View File

@ -20,13 +20,14 @@ class ZillizKBService(KBService):
# if self.zilliz.col:
# self.zilliz.col.flush()
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
    """Fetch the documents whose primary key is in ``ids`` with one query.

    Args:
        ids: Primary keys as decimal strings.

    Returns:
        One ``Document`` per row returned by the query: the row's ``"text"``
        field becomes ``page_content`` and the remaining fields become
        ``metadata``. Empty when ``ids`` is empty or no collection is
        available (``self.zilliz.col`` is falsy).

    Raises:
        ValueError: If an element of ``ids`` is not an integer literal.
    """
    result = []
    if not ids or not self.zilliz.col:
        return result
    # The single-id query interpolated the id unquoted (``pk == 123``).
    # Formatting a list of str would render quoted string literals instead,
    # so convert to int first.
    # NOTE(review): assumes ``pk`` is an integer field, as the single-id
    # query did - confirm against the collection schema.
    pks = [int(_id) for _id in ids]
    data_list = self.zilliz.col.query(expr=f'pk in {pks}', output_fields=["*"])
    for data in data_list:
        text = data.pop("text")
        result.append(Document(page_content=text, metadata=data))
    return result
@staticmethod
def search(zilliz_name, content, limit=3):