From 16542f20b4c538982678113fc08a643b1d28d33f Mon Sep 17 00:00:00 2001 From: zqt <1178747941@qq.com> Date: Tue, 22 Aug 2023 16:52:04 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dpg=E5=92=8Cmilvus=20kbservice?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/knowledge_base/kb_service/milvus_kb_service.py | 5 +++-- server/knowledge_base/kb_service/pg_kb_service.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index f9c40c0..87588ed 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -45,10 +45,10 @@ class MilvusKBService(KBService): def do_drop_kb(self): self.milvus.col.drop() - def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: + def do_search(self, query: str, top_k: int, embeddings: Embeddings): # todo: support score threshold self._load_milvus(embeddings=embeddings) - return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD) + return self.milvus.similarity_search_with_score(query, top_k) def add_doc(self, kb_file: KnowledgeFile): """ @@ -76,6 +76,7 @@ class MilvusKBService(KBService): if __name__ == '__main__': # 测试建表使用 from server.db.base import Base, engine + Base.metadata.create_all(bind=engine) milvusService = MilvusKBService("test") milvusService.add_doc(KnowledgeFile("README.md", "test")) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index a3f9318..d51a113 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -43,7 +43,7 @@ class PGKBService(KBService): ''')) connect.commit() - def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: + def do_search(self, query: str, top_k: int, embeddings: Embeddings): # todo: support score threshold self._load_pg_vector(embeddings=embeddings) return self.pg_vector.similarity_search_with_score(query, top_k) @@ -76,6 +76,7 @@ class PGKBService(KBService): if __name__ == '__main__': from server.db.base import Base, engine + Base.metadata.create_all(bind=engine) pGKBService = PGKBService("test") pGKBService.create_kb()