From 68a544ea334e7c1deeaffef3f6eb94059994d8a8 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Thu, 16 Nov 2023 11:09:40 +0800 Subject: [PATCH] =?UTF-8?q?=E5=BC=80=E5=8F=91=E8=80=85=EF=BC=9AXXKBService?= =?UTF-8?q?.get=5Fdoc=5Fby=5Fid=20=E6=94=B9=E4=B8=BA=E6=89=B9=E9=87=8F?= =?UTF-8?q?=E5=A4=84=E7=90=86=EF=BC=8C=E6=8F=90=E9=AB=98=E8=AE=BF=E9=97=AE?= =?UTF-8?q?=E5=90=91=E9=87=8F=E5=BA=93=E6=95=88=E7=8E=87=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/knowledge_base/kb_service/base.py | 6 +++--- .../knowledge_base/kb_service/faiss_kb_service.py | 4 ++-- .../knowledge_base/kb_service/milvus_kb_service.py | 13 +++++++------ server/knowledge_base/kb_service/pg_kb_service.py | 11 +++++------ .../knowledge_base/kb_service/zilliz_kb_service.py | 11 ++++++----- 5 files changed, 23 insertions(+), 22 deletions(-) diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index e221e43..fa079f6 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -172,15 +172,15 @@ class KBService(ABC): docs = self.do_search(query, top_k, score_threshold) return docs - def get_doc_by_id(self, id: str) -> Optional[Document]: - return None + def get_doc_by_ids(self, ids: List[str]) -> List[Document]: + return [] def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]: ''' 通过file_name或metadata检索Document ''' doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata) - docs = [self.get_doc_by_id(x["id"]) for x in doc_infos] + docs = self.get_doc_by_ids([x["id"] for x in doc_infos]) return docs @abstractmethod diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 07c57e0..a444f03 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -32,9 +32,9 @@ class FaissKBService(KBService): def save_vector_store(self): self.load_vector_store().save(self.vs_path) - def get_doc_by_id(self, id: str) -> Optional[Document]: + def get_doc_by_ids(self, ids: List[str]) -> List[Document]: with self.load_vector_store().acquire() as vs: - return vs.docstore._dict.get(id) + return [vs.docstore._dict.get(id) for id in ids] def do_init(self): self.vector_name = self.vector_name or self.embed_model diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 5e27040..97c5913 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -22,13 +22,14 @@ class MilvusKBService(KBService): # if self.milvus.col: # self.milvus.col.flush() - def get_doc_by_id(self, id: str) -> Optional[Document]: + def get_doc_by_ids(self, ids: List[str]) -> List[Document]: + result = [] if self.milvus.col: - data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"]) - if len(data_list) > 0: - data = data_list[0] + data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"]) + for data in data_list: text = data.pop("text") - return Document(page_content=text, metadata=data) + result.append(Document(page_content=text, metadata=data)) + return result @staticmethod def search(milvus_name, content, limit=3): @@ -99,7 +100,7 @@ if __name__ == '__main__': milvusService = MilvusKBService("test") # milvusService.add_doc(KnowledgeFile("README.md", "test")) - print(milvusService.get_doc_by_id("444022434274215486")) + print(milvusService.get_doc_by_ids(["444022434274215486"])) # milvusService.delete_doc(KnowledgeFile("README.md", "test")) # milvusService.do_drop_kb() # print(milvusService.search_docs("如何启动api服务")) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 3a6bab0..cf58ce3 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -22,13 +22,12 @@ class PGKBService(KBService): distance_strategy=DistanceStrategy.EUCLIDEAN, connection_string=kbs_config.get("pg").get("connection_uri")) - def get_doc_by_id(self, id: str) -> Optional[Document]: + def get_doc_by_ids(self, ids: List[str]) -> List[Document]: with self.pg_vector.connect() as connect: - stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id") + stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids") results = [Document(page_content=row[0], metadata=row[1]) for row in - connect.execute(stmt, parameters={'id': id}).fetchall()] - if len(results) > 0: - return results[0] + connect.execute(stmt, parameters={'ids': ids}).fetchall()] + return results def do_init(self): self._load_pg_vector() @@ -88,5 +87,5 @@ if __name__ == '__main__': # pGKBService.add_doc(KnowledgeFile("README.md", "test")) # pGKBService.delete_doc(KnowledgeFile("README.md", "test")) # pGKBService.drop_kb() - print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772")) + print(pGKBService.get_doc_by_ids(["f1e51390-3029-4a19-90dc-7118aaa25772"])) # print(pGKBService.search_docs("如何启动api服务")) diff --git a/server/knowledge_base/kb_service/zilliz_kb_service.py b/server/knowledge_base/kb_service/zilliz_kb_service.py index bd8b3e9..d82f873 100644 --- a/server/knowledge_base/kb_service/zilliz_kb_service.py +++ b/server/knowledge_base/kb_service/zilliz_kb_service.py @@ -20,13 +20,14 @@ class ZillizKBService(KBService): # if self.zilliz.col: # self.zilliz.col.flush() - def get_doc_by_id(self, id: str) -> Optional[Document]: + def get_doc_by_ids(self, ids: List[str]) -> List[Document]: + result = [] if self.zilliz.col: - data_list = self.zilliz.col.query(expr=f'pk == {id}', output_fields=["*"]) - if len(data_list) > 0: - data = data_list[0] + data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"]) + for data in data_list: text = data.pop("text") - return Document(page_content=text, metadata=data) + result.append(Document(page_content=text, metadata=data)) + return result @staticmethod def search(zilliz_name, content, limit=3):