milvus/pg kb_service需要实现get_doc_by_id方法
This commit is contained in:
parent
0bc9d5c8ee
commit
e21ca447af
|
|
@ -22,9 +22,13 @@ class MilvusKBService(KBService):
|
||||||
from pymilvus import Collection
|
from pymilvus import Collection
|
||||||
return Collection(milvus_name)
|
return Collection(milvus_name)
|
||||||
|
|
||||||
# TODO:
|
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||||
return None
|
if self.milvus.col:
|
||||||
|
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
|
||||||
|
if len(data_list) > 0:
|
||||||
|
data = data_list[0]
|
||||||
|
text = data.pop("text")
|
||||||
|
return Document(page_content=text, metadata=data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def search(milvus_name, content, limit=3):
|
def search(milvus_name, content, limit=3):
|
||||||
|
|
@ -64,10 +68,11 @@ class MilvusKBService(KBService):
|
||||||
return doc_infos
|
return doc_infos
|
||||||
|
|
||||||
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
|
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||||
filepath = kb_file.filepath.replace('\\', '\\\\')
|
if self.milvus.col:
|
||||||
delete_list = [item.get("pk") for item in
|
filepath = kb_file.filepath.replace('\\', '\\\\')
|
||||||
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
|
delete_list = [item.get("pk") for item in
|
||||||
self.milvus.col.delete(expr=f'pk in {delete_list}')
|
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
|
||||||
|
self.milvus.col.delete(expr=f'pk in {delete_list}')
|
||||||
|
|
||||||
def do_clear_vs(self):
|
def do_clear_vs(self):
|
||||||
if self.milvus.col:
|
if self.milvus.col:
|
||||||
|
|
@ -80,7 +85,9 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
milvusService = MilvusKBService("test")
|
milvusService = MilvusKBService("test")
|
||||||
milvusService.add_doc(KnowledgeFile("README.md", "test"))
|
# milvusService.add_doc(KnowledgeFile("README.md", "test"))
|
||||||
milvusService.delete_doc(KnowledgeFile("README.md", "test"))
|
|
||||||
milvusService.do_drop_kb()
|
print(milvusService.get_doc_by_id("444022434274215486"))
|
||||||
print(milvusService.search_docs("如何启动api服务"))
|
# milvusService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||||
|
# milvusService.do_drop_kb()
|
||||||
|
# print(milvusService.search_docs("如何启动api服务"))
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import json
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
|
|
@ -13,6 +14,7 @@ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, Em
|
||||||
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
|
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
|
||||||
from server.utils import embedding_device as get_embedding_device
|
from server.utils import embedding_device as get_embedding_device
|
||||||
|
|
||||||
|
|
||||||
class PGKBService(KBService):
|
class PGKBService(KBService):
|
||||||
pg_vector: PGVector
|
pg_vector: PGVector
|
||||||
|
|
||||||
|
|
@ -24,10 +26,13 @@ class PGKBService(KBService):
|
||||||
collection_name=self.kb_name,
|
collection_name=self.kb_name,
|
||||||
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
||||||
connection_string=kbs_config.get("pg").get("connection_uri"))
|
connection_string=kbs_config.get("pg").get("connection_uri"))
|
||||||
|
|
||||||
# TODO:
|
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||||
return None
|
with self.pg_vector.connect() as connect:
|
||||||
|
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
|
||||||
|
results = [Document(page_content=row[0], metadata=row[1]) for row in
|
||||||
|
connect.execute(stmt, parameters={'id': id}).fetchall()]
|
||||||
|
if len(results) > 0:
|
||||||
|
return results[0]
|
||||||
|
|
||||||
def do_init(self):
|
def do_init(self):
|
||||||
self._load_pg_vector()
|
self._load_pg_vector()
|
||||||
|
|
@ -77,10 +82,11 @@ class PGKBService(KBService):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from server.db.base import Base, engine
|
from server.db.base import Base, engine
|
||||||
|
|
||||||
Base.metadata.create_all(bind=engine)
|
# Base.metadata.create_all(bind=engine)
|
||||||
pGKBService = PGKBService("test")
|
pGKBService = PGKBService("test")
|
||||||
pGKBService.create_kb()
|
# pGKBService.create_kb()
|
||||||
pGKBService.add_doc(KnowledgeFile("README.md", "test"))
|
# pGKBService.add_doc(KnowledgeFile("README.md", "test"))
|
||||||
pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
|
# pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||||
pGKBService.drop_kb()
|
# pGKBService.drop_kb()
|
||||||
print(pGKBService.search_docs("如何启动api服务"))
|
print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772"))
|
||||||
|
# print(pGKBService.search_docs("如何启动api服务"))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue