增加关键词检索并将向量库切换到es
This commit is contained in:
parent
afa51bb298
commit
08cb4b8750
|
|
@ -32,6 +32,7 @@ from chatchat.server.utils import (
|
||||||
get_default_embedding,
|
get_default_embedding,
|
||||||
)
|
)
|
||||||
from chatchat.utils import build_logger
|
from chatchat.utils import build_logger
|
||||||
|
from typing import List, Dict,Tuple
|
||||||
|
|
||||||
logger = build_logger()
|
logger = build_logger()
|
||||||
|
|
||||||
|
|
@ -73,6 +74,8 @@ def search_docs(
|
||||||
docs = kb.search_docs(query, top_k, score_threshold)
|
docs = kb.search_docs(query, top_k, score_threshold)
|
||||||
logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}")
|
logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}")
|
||||||
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
||||||
|
docs_key = kb.search_content_internal(query,2)
|
||||||
|
docs = merge_and_deduplicate(docs, docs_key)
|
||||||
data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
|
data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
|
||||||
elif file_name or metadata:
|
elif file_name or metadata:
|
||||||
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
||||||
|
|
@ -81,6 +84,27 @@ def search_docs(
|
||||||
del d.metadata["vector"]
|
del d.metadata["vector"]
|
||||||
return [x.dict() for x in data]
|
return [x.dict() for x in data]
|
||||||
|
|
||||||
|
def merge_and_deduplicate(
    list1: "List[Tuple[Document, float]]",
    list2: "List[Tuple[Document, float]]",
) -> "List[Tuple[Document, float]]":
    """Merge two search-result lists, dropping entries with duplicate text.

    Entries are keyed by their document's ``page_content``. The first
    occurrence wins, so anything in ``list1`` takes precedence over an entry
    with the same text in ``list2``, and the output keeps first-seen order.

    Args:
        list1: Higher-priority results; ``None`` is treated as empty.
        list2: Lower-priority results; ``None`` is treated as empty.

    Returns:
        The unique entries, each returned exactly as it was passed in.
    """
    # NOTE(review): entries are annotated as (Document, score) tuples, but the
    # visible caller reads ``.metadata`` directly on the merged items, so bare
    # documents appear to be passed in as well. Both shapes are accepted here;
    # confirm which one the callers should standardise on.
    merged_dict = {}
    for results in (list1, list2):
        # Either source may yield None instead of a list; treat it as no hits.
        for item in results or ():
            document = item[0] if isinstance(item, tuple) else item
            # setdefault keeps the first entry seen for a given text.
            merged_dict.setdefault(document.page_content, item)
    return list(merged_dict.values())
|
||||||
|
|
||||||
def list_files(knowledge_base_name: str) -> ListResponse:
|
def list_files(knowledge_base_name: str) -> ListResponse:
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
|
|
|
||||||
|
|
@ -210,6 +210,12 @@ class KBService(ABC):
|
||||||
docs = self.do_search(query, top_k, score_threshold)
|
docs = self.do_search(query, top_k, score_threshold)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
def search_content_internal(
    self,
    query: str,
    top_k: int,
) -> "List[Tuple[Document, float]]":
    """Run the backend's keyword search and always hand back a list.

    Delegates to ``self.searchbyContentInternal(query, top_k)``.

    Args:
        query: Text to search for.
        top_k: Maximum number of results to request from the backend.

    Returns:
        Whatever the backend returned, or an empty list if it returned
        ``None``.
    """
    docs = self.searchbyContentInternal(query, top_k)
    # A backend hook that returns nothing yields None; normalise it so callers
    # can iterate the result without a guard.
    return [] if docs is None else docs
|
||||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -319,6 +325,16 @@ class KBService(ABC):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
def searchbyContentInternal(
    self,
    query: str,
    top_k: int,
) -> List[Tuple[Document, float]]:
    """
    Keyword-search the knowledge base; each subclass supplies its own logic.
    """
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def do_add_doc(
|
def do_add_doc(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
|
||||||
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||||
from chatchat.server.utils import get_Embeddings
|
from chatchat.server.utils import get_Embeddings
|
||||||
from chatchat.utils import build_logger
|
from chatchat.utils import build_logger
|
||||||
|
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||||
|
|
||||||
logger = build_logger()
|
logger = build_logger()
|
||||||
|
|
||||||
|
|
@ -143,6 +143,54 @@ class ESKBService(KBService):
|
||||||
docs = retriever.get_relevant_documents(query)
|
docs = retriever.get_relevant_documents(query)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
def searchbyContent(self, query: str, top_k: int = 2):
    """Keyword-search the index and return highlighted snippets.

    Runs an Elasticsearch ``match`` query on the ``context`` field with
    highlighting enabled, then wraps each hit in a ``DocumentWithVSId`` whose
    ``page_content`` is the hit's highlight fragments joined by single spaces
    (an empty string when the hit carries no highlight).

    Args:
        query: Text to search for.
        top_k: Maximum number of hits requested from Elasticsearch.

    Returns:
        A list of ``DocumentWithVSId``; empty when the index does not exist.
    """
    if not self.es_client_python.indices.exists(index=self.index_name):
        # Return an explicit empty list rather than falling through to None.
        return []

    logger.info(f"******ESKBService searchByContent {self.index_name},query:{query}")
    tem_query = {
        # ``match`` analyses its input as plain text and has no wildcard
        # syntax, so the query is sent as-is rather than wrapped in "*".
        "query": {"match": {"context": query}},
        "highlight": {"fields": {"context": {}}},
    }
    search_results = self.es_client_python.search(
        index=self.index_name, body=tem_query, size=top_k
    )

    results = []
    for hit in search_results["hits"]["hits"]:
        fragments = hit["highlight"]["context"] if "highlight" in hit else []
        results.append(
            DocumentWithVSId(
                page_content=" ".join(fragments),
                metadata=hit["_source"]["metadata"],
                id=hit["_id"],
            )
        )
    return results
|
||||||
|
|
||||||
|
def searchbyContentInternal(self, query: str, top_k: int = 2):
    """Keyword-search the index and return ``(Document, score)`` pairs.

    Runs an Elasticsearch ``match`` query on the ``context`` field and builds
    one ``Document`` per hit from the hit's stored ``context`` and
    ``metadata``.

    Args:
        query: Text to search for.
        top_k: Maximum number of hits requested from Elasticsearch.

    Returns:
        A list of ``(Document, 1.3)`` tuples; empty when the index does not
        exist.
    """
    if not self.es_client_python.indices.exists(index=self.index_name):
        # Return an explicit empty list rather than falling through to None.
        return []

    logger.info(f"******ESKBService searchbyContentInternal {self.index_name},query:{query}")
    # ``match`` analyses its input as plain text and has no wildcard syntax,
    # so the query is sent as-is rather than wrapped in "*".
    tem_query = {"query": {"match": {"context": query}}}
    search_results = self.es_client_python.search(
        index=self.index_name, body=tem_query, size=top_k
    )

    # Every keyword hit is paired with the same fixed score; Elasticsearch's
    # own relevance score is not used.
    # NOTE(review): the reason for the value 1.3 is not evident here - confirm.
    return [
        (
            Document(
                page_content=hit["_source"]["context"],
                metadata=hit["_source"]["metadata"],
            ),
            1.3,
        )
        for hit in search_results["hits"]["hits"]
    ]
|
||||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||||
results = []
|
results = []
|
||||||
for doc_id in ids:
|
for doc_id in ids:
|
||||||
|
|
@ -179,10 +227,12 @@ class ESKBService(KBService):
|
||||||
},
|
},
|
||||||
"track_total_hits": True,
|
"track_total_hits": True,
|
||||||
}
|
}
|
||||||
# 注意设置size,默认返回10个,es检索设置track_total_hits为True返回数据库中真实的size。
|
print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}, base_file_name:{base_file_name}")
|
||||||
size = self.es_client_python.search(body=query)["hits"]["total"]["value"]
|
# 注意设置size,默认返回10个。
|
||||||
search_results = self.es_client_python.search(body=query, size=size)
|
search_results = self.es_client_python.search(index=self.index_name, body=query,size=200)
|
||||||
delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]]
|
delete_list = [hit["_id"] for hit in search_results['hits']['hits']]
|
||||||
|
size = len(delete_list)
|
||||||
|
#print(f"***do_delete_doc: 删除的size:{size}, {delete_list}")
|
||||||
if len(delete_list) == 0:
|
if len(delete_list) == 0:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
|
|
@ -226,6 +276,8 @@ class ESKBService(KBService):
|
||||||
{"id": hit["_id"], "metadata": hit["_source"]["metadata"]}
|
{"id": hit["_id"], "metadata": hit["_source"]["metadata"]}
|
||||||
for hit in search_results["hits"]["hits"]
|
for hit in search_results["hits"]["hits"]
|
||||||
]
|
]
|
||||||
|
#size = len(info_docs)
|
||||||
|
#print(f"do_add_doc 召回元素个数:{size}")
|
||||||
return info_docs
|
return info_docs
|
||||||
|
|
||||||
def do_clear_vs(self):
|
def do_clear_vs(self):
|
||||||
|
|
|
||||||
|
|
@ -78,6 +78,12 @@ class FaissKBService(KBService):
|
||||||
docs = retriever.get_relevant_documents(query)
|
docs = retriever.get_relevant_documents(query)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
def searchbyContent(self, query: str = "", top_k: int = 2) -> list:
    """Placeholder: the Faiss backend has no keyword search.

    ``query`` and ``top_k`` are accepted (and ignored) so the method can be
    called the same way as on a backend that implements it.

    Returns:
        Always an empty list.
    """
    return []
|
||||||
|
|
||||||
|
def searchbyContentInternal(self, query: str = "", top_k: int = 2) -> list:
    """Placeholder: the Faiss backend has no keyword search.

    Takes the ``(query, top_k)`` arguments that the ``KBService`` hook is
    declared with, so invoking it with them no longer raises ``TypeError``.
    Both arguments are ignored.

    Returns:
        Always an empty list, so callers can iterate the result.
    """
    return []
|
||||||
|
|
||||||
def do_add_doc(
|
def do_add_doc(
|
||||||
self,
|
self,
|
||||||
docs: List[Document],
|
docs: List[Document],
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,12 @@ class MilvusKBService(KBService):
|
||||||
docs = retriever.get_relevant_documents(query)
|
docs = retriever.get_relevant_documents(query)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
def searchbyContent(self, query: str = "", top_k: int = 2) -> list:
    """Placeholder: the Milvus backend has no keyword search.

    ``query`` and ``top_k`` are accepted (and ignored) so the method can be
    called the same way as on a backend that implements it.

    Returns:
        Always an empty list.
    """
    return []
|
||||||
|
|
||||||
|
def searchbyContentInternal(self, query: str = "", top_k: int = 2) -> list:
    """Placeholder: the Milvus backend has no keyword search.

    Takes the ``(query, top_k)`` arguments that the ``KBService`` hook is
    declared with, so invoking it with them no longer raises ``TypeError``.
    Both arguments are ignored.

    Returns:
        Always an empty list, so callers can iterate the result.
    """
    return []
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
for k, v in doc.metadata.items():
|
for k, v in doc.metadata.items():
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,12 @@ class PGKBService(KBService):
|
||||||
docs = retriever.get_relevant_documents(query)
|
docs = retriever.get_relevant_documents(query)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
def searchbyContent(self, query: str = "", top_k: int = 2) -> list:
    """Placeholder: the PG backend has no keyword search.

    ``query`` and ``top_k`` are accepted (and ignored) so the method can be
    called the same way as on a backend that implements it.

    Returns:
        Always an empty list.
    """
    return []
|
||||||
|
|
||||||
|
def searchbyContentInternal(self, query: str = "", top_k: int = 2) -> list:
    """Placeholder: the PG backend has no keyword search.

    Takes the ``(query, top_k)`` arguments that the ``KBService`` hook is
    declared with, so invoking it with them no longer raises ``TypeError``.
    Both arguments are ignored.

    Returns:
        Always an empty list, so callers can iterate the result.
    """
    return []
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
ids = self.pg_vector.add_documents(docs)
|
ids = self.pg_vector.add_documents(docs)
|
||||||
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
||||||
|
|
|
||||||
|
|
@ -94,6 +94,13 @@ class RelytKBService(KBService):
|
||||||
docs = self.relyt.similarity_search_with_score(query, top_k)
|
docs = self.relyt.similarity_search_with_score(query, top_k)
|
||||||
return score_threshold_process(score_threshold, top_k, docs)
|
return score_threshold_process(score_threshold, top_k, docs)
|
||||||
|
|
||||||
|
|
||||||
|
def searchbyContent(self, query: str = "", top_k: int = 2) -> list:
    """Placeholder: the Relyt backend has no keyword search.

    ``query`` and ``top_k`` are accepted (and ignored) so the method can be
    called the same way as on a backend that implements it.

    Returns:
        Always an empty list.
    """
    return []
|
||||||
|
|
||||||
|
def searchbyContentInternal(self, query: str = "", top_k: int = 2) -> list:
    """Placeholder: the Relyt backend has no keyword search.

    Takes the ``(query, top_k)`` arguments that the ``KBService`` hook is
    declared with, so invoking it with them no longer raises ``TypeError``.
    Both arguments are ignored.

    Returns:
        Always an empty list, so callers can iterate the result.
    """
    return []
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
print(docs)
|
print(docs)
|
||||||
ids = self.relyt.add_documents(docs)
|
ids = self.relyt.add_documents(docs)
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,12 @@ class ZillizKBService(KBService):
|
||||||
docs = retriever.get_relevant_documents(query)
|
docs = retriever.get_relevant_documents(query)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
def searchbyContent(self, query: str = "", top_k: int = 2) -> list:
    """Placeholder: the Zilliz backend has no keyword search.

    ``query`` and ``top_k`` are accepted (and ignored) so the method can be
    called the same way as on a backend that implements it.

    Returns:
        Always an empty list.
    """
    return []
|
||||||
|
|
||||||
|
def searchbyContentInternal(self, query: str = "", top_k: int = 2) -> list:
    """Placeholder: the Zilliz backend has no keyword search.

    Takes the ``(query, top_k)`` arguments that the ``KBService`` hook is
    declared with, so invoking it with them no longer raises ``TypeError``.
    Both arguments are ignored.

    Returns:
        Always an empty list, so callers can iterate the result.
    """
    return []
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
for k, v in doc.metadata.items():
|
for k, v in doc.metadata.items():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue