开发者:XXKBService.get_doc_by_id 改为批量处理,提高访问向量库效率。
This commit is contained in:
parent
fbe214471b
commit
68a544ea33
|
|
@ -172,15 +172,15 @@ class KBService(ABC):
|
||||||
docs = self.do_search(query, top_k, score_threshold)
|
docs = self.do_search(query, top_k, score_threshold)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||||
return None
|
return []
|
||||||
|
|
||||||
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
|
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
|
||||||
'''
|
'''
|
||||||
通过file_name或metadata检索Document
|
通过file_name或metadata检索Document
|
||||||
'''
|
'''
|
||||||
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
|
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
|
||||||
docs = [self.get_doc_by_id(x["id"]) for x in doc_infos]
|
docs = self.get_doc_by_ids([x["id"] for x in doc_infos])
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -32,9 +32,9 @@ class FaissKBService(KBService):
|
||||||
def save_vector_store(self):
|
def save_vector_store(self):
|
||||||
self.load_vector_store().save(self.vs_path)
|
self.load_vector_store().save(self.vs_path)
|
||||||
|
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||||
with self.load_vector_store().acquire() as vs:
|
with self.load_vector_store().acquire() as vs:
|
||||||
return vs.docstore._dict.get(id)
|
return [vs.docstore._dict.get(id) for id in ids]
|
||||||
|
|
||||||
def do_init(self):
|
def do_init(self):
|
||||||
self.vector_name = self.vector_name or self.embed_model
|
self.vector_name = self.vector_name or self.embed_model
|
||||||
|
|
|
||||||
|
|
@ -22,13 +22,14 @@ class MilvusKBService(KBService):
|
||||||
# if self.milvus.col:
|
# if self.milvus.col:
|
||||||
# self.milvus.col.flush()
|
# self.milvus.col.flush()
|
||||||
|
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||||
|
result = []
|
||||||
if self.milvus.col:
|
if self.milvus.col:
|
||||||
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
|
data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"])
|
||||||
if len(data_list) > 0:
|
for data in data_list:
|
||||||
data = data_list[0]
|
|
||||||
text = data.pop("text")
|
text = data.pop("text")
|
||||||
return Document(page_content=text, metadata=data)
|
result.append(Document(page_content=text, metadata=data))
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def search(milvus_name, content, limit=3):
|
def search(milvus_name, content, limit=3):
|
||||||
|
|
@ -99,7 +100,7 @@ if __name__ == '__main__':
|
||||||
milvusService = MilvusKBService("test")
|
milvusService = MilvusKBService("test")
|
||||||
# milvusService.add_doc(KnowledgeFile("README.md", "test"))
|
# milvusService.add_doc(KnowledgeFile("README.md", "test"))
|
||||||
|
|
||||||
print(milvusService.get_doc_by_id("444022434274215486"))
|
print(milvusService.get_doc_by_ids(["444022434274215486"]))
|
||||||
# milvusService.delete_doc(KnowledgeFile("README.md", "test"))
|
# milvusService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||||
# milvusService.do_drop_kb()
|
# milvusService.do_drop_kb()
|
||||||
# print(milvusService.search_docs("如何启动api服务"))
|
# print(milvusService.search_docs("如何启动api服务"))
|
||||||
|
|
|
||||||
|
|
@ -22,13 +22,12 @@ class PGKBService(KBService):
|
||||||
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
||||||
connection_string=kbs_config.get("pg").get("connection_uri"))
|
connection_string=kbs_config.get("pg").get("connection_uri"))
|
||||||
|
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||||
with self.pg_vector.connect() as connect:
|
with self.pg_vector.connect() as connect:
|
||||||
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
|
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids")
|
||||||
results = [Document(page_content=row[0], metadata=row[1]) for row in
|
results = [Document(page_content=row[0], metadata=row[1]) for row in
|
||||||
connect.execute(stmt, parameters={'id': id}).fetchall()]
|
connect.execute(stmt, parameters={'ids': ids}).fetchall()]
|
||||||
if len(results) > 0:
|
return results
|
||||||
return results[0]
|
|
||||||
|
|
||||||
def do_init(self):
|
def do_init(self):
|
||||||
self._load_pg_vector()
|
self._load_pg_vector()
|
||||||
|
|
@ -88,5 +87,5 @@ if __name__ == '__main__':
|
||||||
# pGKBService.add_doc(KnowledgeFile("README.md", "test"))
|
# pGKBService.add_doc(KnowledgeFile("README.md", "test"))
|
||||||
# pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
|
# pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||||
# pGKBService.drop_kb()
|
# pGKBService.drop_kb()
|
||||||
print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772"))
|
print(pGKBService.get_doc_by_ids(["f1e51390-3029-4a19-90dc-7118aaa25772"]))
|
||||||
# print(pGKBService.search_docs("如何启动api服务"))
|
# print(pGKBService.search_docs("如何启动api服务"))
|
||||||
|
|
|
||||||
|
|
@ -20,13 +20,14 @@ class ZillizKBService(KBService):
|
||||||
# if self.zilliz.col:
|
# if self.zilliz.col:
|
||||||
# self.zilliz.col.flush()
|
# self.zilliz.col.flush()
|
||||||
|
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||||
|
result = []
|
||||||
if self.zilliz.col:
|
if self.zilliz.col:
|
||||||
data_list = self.zilliz.col.query(expr=f'pk == {id}', output_fields=["*"])
|
data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
|
||||||
if len(data_list) > 0:
|
for data in data_list:
|
||||||
data = data_list[0]
|
|
||||||
text = data.pop("text")
|
text = data.pop("text")
|
||||||
return Document(page_content=text, metadata=data)
|
result.append(Document(page_content=text, metadata=data))
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def search(zilliz_name, content, limit=3):
|
def search(zilliz_name, content, limit=3):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue