milvus/pg kb_service需要实现get_doc_by_id方法

This commit is contained in:
zqt 2023-09-04 16:40:05 +08:00
parent 0bc9d5c8ee
commit e21ca447af
2 changed files with 32 additions and 19 deletions

View File

@ -22,9 +22,13 @@ class MilvusKBService(KBService):
from pymilvus import Collection
return Collection(milvus_name)
# TODO:
def get_doc_by_id(self, id: str) -> Optional[Document]:
return None
if self.milvus.col:
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
if len(data_list) > 0:
data = data_list[0]
text = data.pop("text")
return Document(page_content=text, metadata=data)
@staticmethod
def search(milvus_name, content, limit=3):
@ -64,6 +68,7 @@ class MilvusKBService(KBService):
return doc_infos
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
if self.milvus.col:
filepath = kb_file.filepath.replace('\\', '\\\\')
delete_list = [item.get("pk") for item in
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
@ -80,7 +85,9 @@ if __name__ == '__main__':
Base.metadata.create_all(bind=engine)
milvusService = MilvusKBService("test")
milvusService.add_doc(KnowledgeFile("README.md", "test"))
milvusService.delete_doc(KnowledgeFile("README.md", "test"))
milvusService.do_drop_kb()
print(milvusService.search_docs("如何启动api服务"))
# milvusService.add_doc(KnowledgeFile("README.md", "test"))
print(milvusService.get_doc_by_id("444022434274215486"))
# milvusService.delete_doc(KnowledgeFile("README.md", "test"))
# milvusService.do_drop_kb()
# print(milvusService.search_docs("如何启动api服务"))

View File

@ -1,3 +1,4 @@
import json
from typing import List, Dict, Optional
from langchain.embeddings.base import Embeddings
@ -13,6 +14,7 @@ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, Em
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
from server.utils import embedding_device as get_embedding_device
class PGKBService(KBService):
pg_vector: PGVector
@ -24,10 +26,13 @@ class PGKBService(KBService):
collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri"))
# TODO:
def get_doc_by_id(self, id: str) -> Optional[Document]:
return None
with self.pg_vector.connect() as connect:
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
results = [Document(page_content=row[0], metadata=row[1]) for row in
connect.execute(stmt, parameters={'id': id}).fetchall()]
if len(results) > 0:
return results[0]
def do_init(self):
self._load_pg_vector()
@ -77,10 +82,11 @@ class PGKBService(KBService):
if __name__ == '__main__':
from server.db.base import Base, engine
Base.metadata.create_all(bind=engine)
# Base.metadata.create_all(bind=engine)
pGKBService = PGKBService("test")
pGKBService.create_kb()
pGKBService.add_doc(KnowledgeFile("README.md", "test"))
pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
pGKBService.drop_kb()
print(pGKBService.search_docs("如何启动api服务"))
# pGKBService.create_kb()
# pGKBService.add_doc(KnowledgeFile("README.md", "test"))
# pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
# pGKBService.drop_kb()
print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772"))
# print(pGKBService.search_docs("如何启动api服务"))