Merge pull request #1199 from zqt996/master

修复pg和milvus kbservice代码
This commit is contained in:
zqt996 2023-08-22 16:53:39 +08:00 committed by GitHub
commit d79676cad1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 3 deletions

View File

@ -45,10 +45,10 @@ class MilvusKBService(KBService):
def do_drop_kb(self): def do_drop_kb(self):
self.milvus.col.drop() self.milvus.col.drop()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: def do_search(self, query: str, top_k: int, embeddings: Embeddings):
# todo: support score threshold # todo: support score threshold
self._load_milvus(embeddings=embeddings) self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD) return self.milvus.similarity_search_with_score(query, top_k)
def add_doc(self, kb_file: KnowledgeFile): def add_doc(self, kb_file: KnowledgeFile):
""" """
@ -76,6 +76,7 @@ class MilvusKBService(KBService):
if __name__ == '__main__': if __name__ == '__main__':
# 测试建表使用 # 测试建表使用
from server.db.base import Base, engine from server.db.base import Base, engine
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
milvusService = MilvusKBService("test") milvusService = MilvusKBService("test")
milvusService.add_doc(KnowledgeFile("README.md", "test")) milvusService.add_doc(KnowledgeFile("README.md", "test"))

View File

@ -43,7 +43,7 @@ class PGKBService(KBService):
''')) '''))
connect.commit() connect.commit()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: def do_search(self, query: str, top_k: int, embeddings: Embeddings):
# todo: support score threshold # todo: support score threshold
self._load_pg_vector(embeddings=embeddings) self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search_with_score(query, top_k) return self.pg_vector.similarity_search_with_score(query, top_k)
@ -76,6 +76,7 @@ class PGKBService(KBService):
if __name__ == '__main__': if __name__ == '__main__':
from server.db.base import Base, engine from server.db.base import Base, engine
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
pGKBService = PGKBService("test") pGKBService = PGKBService("test")
pGKBService.create_kb() pGKBService.create_kb()