开发者:XXKBService.get_doc_by_id 改为批量处理,提高访问向量库效率。
This commit is contained in:
parent
fbe214471b
commit
68a544ea33
|
|
@ -172,15 +172,15 @@ class KBService(ABC):
|
|||
docs = self.do_search(query, top_k, score_threshold)
|
||||
return docs
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
return None
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||
return []
|
||||
|
||||
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
|
||||
'''
|
||||
通过file_name或metadata检索Document
|
||||
'''
|
||||
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
|
||||
docs = [self.get_doc_by_id(x["id"]) for x in doc_infos]
|
||||
docs = self.get_doc_by_ids([x["id"] for x in doc_infos])
|
||||
return docs
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -32,9 +32,9 @@ class FaissKBService(KBService):
|
|||
def save_vector_store(self):
|
||||
self.load_vector_store().save(self.vs_path)
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
return vs.docstore._dict.get(id)
|
||||
return [vs.docstore._dict.get(id) for id in ids]
|
||||
|
||||
def do_init(self):
|
||||
self.vector_name = self.vector_name or self.embed_model
|
||||
|
|
|
|||
|
|
@ -22,13 +22,14 @@ class MilvusKBService(KBService):
|
|||
# if self.milvus.col:
|
||||
# self.milvus.col.flush()
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||
result = []
|
||||
if self.milvus.col:
|
||||
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
|
||||
if len(data_list) > 0:
|
||||
data = data_list[0]
|
||||
data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"])
|
||||
for data in data_list:
|
||||
text = data.pop("text")
|
||||
return Document(page_content=text, metadata=data)
|
||||
result.append(Document(page_content=text, metadata=data))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def search(milvus_name, content, limit=3):
|
||||
|
|
@ -99,7 +100,7 @@ if __name__ == '__main__':
|
|||
milvusService = MilvusKBService("test")
|
||||
# milvusService.add_doc(KnowledgeFile("README.md", "test"))
|
||||
|
||||
print(milvusService.get_doc_by_id("444022434274215486"))
|
||||
print(milvusService.get_doc_by_ids(["444022434274215486"]))
|
||||
# milvusService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||
# milvusService.do_drop_kb()
|
||||
# print(milvusService.search_docs("如何启动api服务"))
|
||||
|
|
|
|||
|
|
@ -22,13 +22,12 @@ class PGKBService(KBService):
|
|||
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
||||
connection_string=kbs_config.get("pg").get("connection_uri"))
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||
with self.pg_vector.connect() as connect:
|
||||
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
|
||||
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids")
|
||||
results = [Document(page_content=row[0], metadata=row[1]) for row in
|
||||
connect.execute(stmt, parameters={'id': id}).fetchall()]
|
||||
if len(results) > 0:
|
||||
return results[0]
|
||||
connect.execute(stmt, parameters={'ids': ids}).fetchall()]
|
||||
return results
|
||||
|
||||
def do_init(self):
|
||||
self._load_pg_vector()
|
||||
|
|
@ -88,5 +87,5 @@ if __name__ == '__main__':
|
|||
# pGKBService.add_doc(KnowledgeFile("README.md", "test"))
|
||||
# pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||
# pGKBService.drop_kb()
|
||||
print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772"))
|
||||
print(pGKBService.get_doc_by_ids(["f1e51390-3029-4a19-90dc-7118aaa25772"]))
|
||||
# print(pGKBService.search_docs("如何启动api服务"))
|
||||
|
|
|
|||
|
|
@ -20,13 +20,14 @@ class ZillizKBService(KBService):
|
|||
# if self.zilliz.col:
|
||||
# self.zilliz.col.flush()
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||
result = []
|
||||
if self.zilliz.col:
|
||||
data_list = self.zilliz.col.query(expr=f'pk == {id}', output_fields=["*"])
|
||||
if len(data_list) > 0:
|
||||
data = data_list[0]
|
||||
data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
|
||||
for data in data_list:
|
||||
text = data.pop("text")
|
||||
return Document(page_content=text, metadata=data)
|
||||
result.append(Document(page_content=text, metadata=data))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def search(zilliz_name, content, limit=3):
|
||||
|
|
|
|||
Loading…
Reference in New Issue