milvus/pg kb_service需要实现get_doc_by_id方法
This commit is contained in:
parent
0bc9d5c8ee
commit
e21ca447af
|
|
@ -22,9 +22,13 @@ class MilvusKBService(KBService):
|
|||
from pymilvus import Collection
|
||||
return Collection(milvus_name)
|
||||
|
||||
# TODO:
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
return None
|
||||
if self.milvus.col:
|
||||
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
|
||||
if len(data_list) > 0:
|
||||
data = data_list[0]
|
||||
text = data.pop("text")
|
||||
return Document(page_content=text, metadata=data)
|
||||
|
||||
@staticmethod
|
||||
def search(milvus_name, content, limit=3):
|
||||
|
|
@ -64,10 +68,11 @@ class MilvusKBService(KBService):
|
|||
return doc_infos
|
||||
|
||||
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
filepath = kb_file.filepath.replace('\\', '\\\\')
|
||||
delete_list = [item.get("pk") for item in
|
||||
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
|
||||
self.milvus.col.delete(expr=f'pk in {delete_list}')
|
||||
if self.milvus.col:
|
||||
filepath = kb_file.filepath.replace('\\', '\\\\')
|
||||
delete_list = [item.get("pk") for item in
|
||||
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
|
||||
self.milvus.col.delete(expr=f'pk in {delete_list}')
|
||||
|
||||
def do_clear_vs(self):
|
||||
if self.milvus.col:
|
||||
|
|
@ -80,7 +85,9 @@ if __name__ == '__main__':
|
|||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
milvusService = MilvusKBService("test")
|
||||
milvusService.add_doc(KnowledgeFile("README.md", "test"))
|
||||
milvusService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||
milvusService.do_drop_kb()
|
||||
print(milvusService.search_docs("如何启动api服务"))
|
||||
# milvusService.add_doc(KnowledgeFile("README.md", "test"))
|
||||
|
||||
print(milvusService.get_doc_by_id("444022434274215486"))
|
||||
# milvusService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||
# milvusService.do_drop_kb()
|
||||
# print(milvusService.search_docs("如何启动api服务"))
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
|
@ -13,6 +14,7 @@ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, Em
|
|||
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
|
||||
from server.utils import embedding_device as get_embedding_device
|
||||
|
||||
|
||||
class PGKBService(KBService):
|
||||
pg_vector: PGVector
|
||||
|
||||
|
|
@ -24,10 +26,13 @@ class PGKBService(KBService):
|
|||
collection_name=self.kb_name,
|
||||
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
||||
connection_string=kbs_config.get("pg").get("connection_uri"))
|
||||
|
||||
# TODO:
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
return None
|
||||
with self.pg_vector.connect() as connect:
|
||||
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
|
||||
results = [Document(page_content=row[0], metadata=row[1]) for row in
|
||||
connect.execute(stmt, parameters={'id': id}).fetchall()]
|
||||
if len(results) > 0:
|
||||
return results[0]
|
||||
|
||||
def do_init(self):
|
||||
self._load_pg_vector()
|
||||
|
|
@ -77,10 +82,11 @@ class PGKBService(KBService):
|
|||
if __name__ == '__main__':
|
||||
from server.db.base import Base, engine
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
# Base.metadata.create_all(bind=engine)
|
||||
pGKBService = PGKBService("test")
|
||||
pGKBService.create_kb()
|
||||
pGKBService.add_doc(KnowledgeFile("README.md", "test"))
|
||||
pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||
pGKBService.drop_kb()
|
||||
print(pGKBService.search_docs("如何启动api服务"))
|
||||
# pGKBService.create_kb()
|
||||
# pGKBService.add_doc(KnowledgeFile("README.md", "test"))
|
||||
# pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||
# pGKBService.drop_kb()
|
||||
print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772"))
|
||||
# print(pGKBService.search_docs("如何启动api服务"))
|
||||
|
|
|
|||
Loading…
Reference in New Issue