diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index f9c40c0..87588ed 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -45,10 +45,10 @@ class MilvusKBService(KBService): def do_drop_kb(self): self.milvus.col.drop() - def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: + def do_search(self, query: str, top_k: int, embeddings: Embeddings): # todo: support score threshold self._load_milvus(embeddings=embeddings) - return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD) + return self.milvus.similarity_search_with_score(query, top_k) def add_doc(self, kb_file: KnowledgeFile): """ @@ -76,6 +76,7 @@ class MilvusKBService(KBService): if __name__ == '__main__': # 测试建表使用 from server.db.base import Base, engine + Base.metadata.create_all(bind=engine) milvusService = MilvusKBService("test") milvusService.add_doc(KnowledgeFile("README.md", "test")) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index a3f9318..d51a113 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -43,7 +43,7 @@ class PGKBService(KBService): ''')) connect.commit() - def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: + def do_search(self, query: str, top_k: int, embeddings: Embeddings): # todo: support score threshold self._load_pg_vector(embeddings=embeddings) return self.pg_vector.similarity_search_with_score(query, top_k) @@ -76,6 +76,7 @@ class PGKBService(KBService): if __name__ == '__main__': from server.db.base import Base, engine + Base.metadata.create_all(bind=engine) pGKBService = PGKBService("test") pGKBService.create_kb()