From e21ca447af8993af043d6d4b6da5a513a9144de8 Mon Sep 17 00:00:00 2001 From: zqt <1178747941@qq.com> Date: Mon, 4 Sep 2023 16:40:05 +0800 Subject: [PATCH] =?UTF-8?q?milvus/pg=20kb=5Fservice=E9=9C=80=E8=A6=81?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0get=5Fdoc=5Fby=5Fid=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kb_service/milvus_kb_service.py | 27 ++++++++++++------- .../kb_service/pg_kb_service.py | 24 ++++++++++------- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 9819713..444765f 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -22,9 +22,13 @@ class MilvusKBService(KBService): from pymilvus import Collection return Collection(milvus_name) - # TODO: def get_doc_by_id(self, id: str) -> Optional[Document]: - return None + if self.milvus.col: + data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"]) + if len(data_list) > 0: + data = data_list[0] + text = data.pop("text") + return Document(page_content=text, metadata=data) @staticmethod def search(milvus_name, content, limit=3): @@ -64,10 +68,11 @@ class MilvusKBService(KBService): return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): - filepath = kb_file.filepath.replace('\\', '\\\\') - delete_list = [item.get("pk") for item in - self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])] - self.milvus.col.delete(expr=f'pk in {delete_list}') + if self.milvus.col: + filepath = kb_file.filepath.replace('\\', '\\\\') + delete_list = [item.get("pk") for item in + self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])] + self.milvus.col.delete(expr=f'pk in {delete_list}') def do_clear_vs(self): if self.milvus.col: @@ -80,7 +85,9 @@ if __name__ == '__main__': Base.metadata.create_all(bind=engine) milvusService = MilvusKBService("test") - milvusService.add_doc(KnowledgeFile("README.md", "test")) - milvusService.delete_doc(KnowledgeFile("README.md", "test")) - milvusService.do_drop_kb() - print(milvusService.search_docs("如何启动api服务")) + # milvusService.add_doc(KnowledgeFile("README.md", "test")) + + print(milvusService.get_doc_by_id("444022434274215486")) + # milvusService.delete_doc(KnowledgeFile("README.md", "test")) + # milvusService.do_drop_kb() + # print(milvusService.search_docs("如何启动api服务")) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index afe9f45..e6381fa 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -1,3 +1,4 @@ +import json from typing import List, Dict, Optional from langchain.embeddings.base import Embeddings @@ -13,6 +14,7 @@ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, Em from server.knowledge_base.utils import load_embeddings, KnowledgeFile from server.utils import embedding_device as get_embedding_device + class PGKBService(KBService): pg_vector: PGVector @@ -24,10 +26,13 @@ class PGKBService(KBService): collection_name=self.kb_name, distance_strategy=DistanceStrategy.EUCLIDEAN, connection_string=kbs_config.get("pg").get("connection_uri")) - - # TODO: def get_doc_by_id(self, id: str) -> Optional[Document]: - return None + with self.pg_vector.connect() as connect: + stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id") + results = [Document(page_content=row[0], metadata=row[1]) for row in + connect.execute(stmt, parameters={'id': id}).fetchall()] + if len(results) > 0: + return results[0] def do_init(self): self._load_pg_vector() @@ -77,10 +82,11 @@ class PGKBService(KBService): if __name__ == '__main__': from server.db.base import Base, engine - Base.metadata.create_all(bind=engine) + # Base.metadata.create_all(bind=engine) pGKBService = PGKBService("test") - pGKBService.create_kb() - pGKBService.add_doc(KnowledgeFile("README.md", "test")) - pGKBService.delete_doc(KnowledgeFile("README.md", "test")) - pGKBService.drop_kb() - print(pGKBService.search_docs("如何启动api服务")) + # pGKBService.create_kb() + # pGKBService.add_doc(KnowledgeFile("README.md", "test")) + # pGKBService.delete_doc(KnowledgeFile("README.md", "test")) + # pGKBService.drop_kb() + print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772")) + # print(pGKBService.search_docs("如何启动api服务"))