增加关键词检索并将向量库切换到es

This commit is contained in:
weiweiw 2025-01-20 15:30:05 +08:00
parent afa51bb298
commit 08cb4b8750
8 changed files with 128 additions and 5 deletions

View File

@ -32,6 +32,7 @@ from chatchat.server.utils import (
get_default_embedding, get_default_embedding,
) )
from chatchat.utils import build_logger from chatchat.utils import build_logger
from typing import List, Dict,Tuple
logger = build_logger() logger = build_logger()
@ -73,6 +74,8 @@ def search_docs(
docs = kb.search_docs(query, top_k, score_threshold) docs = kb.search_docs(query, top_k, score_threshold)
logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}") logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}")
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
docs_key = kb.search_content_internal(query,2)
docs = merge_and_deduplicate(docs, docs_key)
data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs] data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
elif file_name or metadata: elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata) data = kb.list_docs(file_name=file_name, metadata=metadata)
@ -81,6 +84,27 @@ def search_docs(
del d.metadata["vector"] del d.metadata["vector"]
return [x.dict() for x in data] return [x.dict() for x in data]
def merge_and_deduplicate(list1: List[Tuple[Document, float]], list2: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
    """Merge two (Document, score) result lists, de-duplicating by page_content.

    The first occurrence of each ``page_content`` wins, so entries from
    ``list1`` (vector search hits) take precedence over entries from
    ``list2`` (keyword search hits) with identical content.

    Args:
        list1: Primary search results as (Document, score) tuples.
        list2: Secondary search results to append when not already present.

    Returns:
        Merged list of (Document, score) tuples in order of first
        appearance (dict insertion order is guaranteed in Python 3.7+).
    """
    # One pass over both lists replaces the two previously duplicated loops;
    # setdefault keeps the first tuple seen for each page_content.
    merged_dict: Dict[str, Tuple[Document, float]] = {}
    for item in list1 + list2:
        document, _score = item
        merged_dict.setdefault(document.page_content, item)
    return list(merged_dict.values())
def list_files(knowledge_base_name: str) -> ListResponse: def list_files(knowledge_base_name: str) -> ListResponse:
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):

View File

@ -210,6 +210,12 @@ class KBService(ABC):
docs = self.do_search(query, top_k, score_threshold) docs = self.do_search(query, top_k, score_threshold)
return docs return docs
def search_content_internal(self,
                            query: str,
                            top_k: int,
                            ) -> List[Tuple[Document, float]]:
    """Keyword search against the knowledge base.

    Delegates to the backend-specific ``searchbyContentInternal`` hook
    implemented by each vector-store subclass.

    Args:
        query: Raw keyword query string.
        top_k: Maximum number of hits to return.

    Returns:
        (Document, score) tuples produced by the backend implementation.
    """
    return self.searchbyContentInternal(query, top_k)
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
return [] return []
@ -319,6 +325,16 @@ class KBService(ABC):
""" """
pass pass
@abstractmethod
def searchbyContentInternal(self,
                            query: str,
                            top_k: int,
                            )->List[Tuple[Document, float]]:
    """
    Keyword-based search of the knowledge base; each subclass implements
    its own backend-specific logic.
    """
    pass
@abstractmethod @abstractmethod
def do_add_doc( def do_add_doc(
self, self,

View File

@ -16,7 +16,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from chatchat.utils import build_logger from chatchat.utils import build_logger
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
logger = build_logger() logger = build_logger()
@ -143,6 +143,54 @@ class ESKBService(KBService):
docs = retriever.get_relevant_documents(query) docs = retriever.get_relevant_documents(query)
return docs return docs
def searchbyContent(self, query: str, top_k: int = 2):
    """Keyword search returning ES highlight snippets as documents.

    Runs a ``match`` query against the ``context`` field and wraps the
    highlighted fragments of each hit in a DocumentWithVSId.

    Args:
        query: Keyword query string.
        top_k: Maximum number of hits to return (passed as ES ``size``).

    Returns:
        List of DocumentWithVSId whose ``page_content`` is the joined
        highlight fragments ("" when ES returned no highlight for a hit);
        empty list when the index does not exist.
    """
    if not self.es_client_python.indices.exists(index=self.index_name):
        # Fix: previously fell through and implicitly returned None, which
        # breaks any caller that iterates over the result.
        return []
    logger.info(f"******ESKBService searchByContent {self.index_name},query:{query}")
    # NOTE(review): a ``match`` query does not interpret ``*`` wildcards —
    # the asterisks are analyzed as literal text (and typically dropped by
    # the analyzer). If wildcard semantics are intended, an ES ``wildcard``
    # query should be used instead; confirm before changing the DSL.
    tem_query = {
        "query": {"match": {
            "context": "*" + query + "*"
        }},
        "highlight": {"fields": {
            "context": {}
        }}
    }
    search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k)
    docs_and_scores = []
    for hit in search_results["hits"]["hits"]:
        highlighted_contexts = ""
        if 'highlight' in hit:
            # Join all highlight fragments into a single snippet string.
            highlighted_contexts = " ".join(hit['highlight']['context'])
        docs_and_scores.append(DocumentWithVSId(
            page_content=highlighted_contexts,
            metadata=hit["_source"]["metadata"],
            id=hit["_id"],
        ))
    return docs_and_scores
def searchbyContentInternal(self, query: str, top_k: int = 2):
    """Keyword search returning raw (Document, score) tuples.

    Runs a ``match`` query against the ``context`` field; used by
    ``KBService.search_content_internal`` to supplement vector search.

    Args:
        query: Keyword query string.
        top_k: Maximum number of hits to return (passed as ES ``size``).

    Returns:
        List of (Document, score) tuples; empty list when the index does
        not exist.
    """
    if not self.es_client_python.indices.exists(index=self.index_name):
        # Fix: previously returned None implicitly, which crashes
        # merge_and_deduplicate() in the search_docs keyword path.
        return []
    logger.info(f"******ESKBService searchbyContentInternal {self.index_name},query:{query}")
    # NOTE(review): ``match`` queries do not treat ``*`` as a wildcard; the
    # asterisks are analyzed as literal text. Use a ``wildcard`` query if
    # substring matching is actually intended — confirm before changing.
    tem_query = {
        "query": {"match": {
            "context": "*" + query + "*"
        }}
    }
    search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k)
    # 1.3 is a fixed pseudo-score: ES relevance is not propagated here.
    # Presumably chosen to rank keyword hits against vector-distance
    # scores downstream — TODO confirm the intended weighting.
    return [
        (
            Document(
                page_content=hit["_source"]["context"],
                metadata=hit["_source"]["metadata"],
            ),
            1.3,
        )
        for hit in search_results["hits"]["hits"]
    ]
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
results = [] results = []
for doc_id in ids: for doc_id in ids:
@ -179,10 +227,12 @@ class ESKBService(KBService):
}, },
"track_total_hits": True, "track_total_hits": True,
} }
# 注意设置size默认返回10个es检索设置track_total_hits为True返回数据库中真实的size。 print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}, base_file_name:{base_file_name}")
size = self.es_client_python.search(body=query)["hits"]["total"]["value"] # 注意设置size默认返回10个。
search_results = self.es_client_python.search(body=query, size=size) search_results = self.es_client_python.search(index=self.index_name, body=query,size=200)
delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]] delete_list = [hit["_id"] for hit in search_results['hits']['hits']]
size = len(delete_list)
#print(f"***do_delete_doc: 删除的size:{size}, {delete_list}")
if len(delete_list) == 0: if len(delete_list) == 0:
return None return None
else: else:
@ -226,6 +276,8 @@ class ESKBService(KBService):
{"id": hit["_id"], "metadata": hit["_source"]["metadata"]} {"id": hit["_id"], "metadata": hit["_source"]["metadata"]}
for hit in search_results["hits"]["hits"] for hit in search_results["hits"]["hits"]
] ]
#size = len(info_docs)
#print(f"do_add_doc 召回元素个数:{size}")
return info_docs return info_docs
def do_clear_vs(self): def do_clear_vs(self):

View File

@ -78,6 +78,12 @@ class FaissKBService(KBService):
docs = retriever.get_relevant_documents(query) docs = retriever.get_relevant_documents(query)
return docs return docs
def searchbyContent(self, query: str = "", top_k: int = 2):
    """Keyword search is not implemented for the FAISS backend.

    Fix: the original stub took only ``self``, so any call with
    (query, top_k) raised TypeError. Defaults keep zero-arg calls working.
    Returns an empty list so callers can iterate the result safely.
    """
    return []
def searchbyContentInternal(self, query: str = "", top_k: int = 2):
    """Keyword search is not implemented for the FAISS backend.

    Fix: the base class invokes self.searchbyContentInternal(query, top_k),
    so the original zero-argument stub raised TypeError; returning [] (not
    None) keeps merge_and_deduplicate() from crashing on the result.
    """
    return []
def do_add_doc( def do_add_doc(
self, self,
docs: List[Document], docs: List[Document],

View File

@ -88,6 +88,12 @@ class MilvusKBService(KBService):
docs = retriever.get_relevant_documents(query) docs = retriever.get_relevant_documents(query)
return docs return docs
def searchbyContent(self, query: str = "", top_k: int = 2):
    """Keyword search is not implemented for the Milvus backend.

    Fix: the original stub took only ``self``, so any call with
    (query, top_k) raised TypeError. Defaults keep zero-arg calls working.
    Returns an empty list so callers can iterate the result safely.
    """
    return []
def searchbyContentInternal(self, query: str = "", top_k: int = 2):
    """Keyword search is not implemented for the Milvus backend.

    Fix: the base class invokes self.searchbyContentInternal(query, top_k),
    so the original zero-argument stub raised TypeError; returning [] (not
    None) keeps merge_and_deduplicate() from crashing on the result.
    """
    return []
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs: for doc in docs:
for k, v in doc.metadata.items(): for k, v in doc.metadata.items():

View File

@ -84,6 +84,12 @@ class PGKBService(KBService):
docs = retriever.get_relevant_documents(query) docs = retriever.get_relevant_documents(query)
return docs return docs
def searchbyContent(self, query: str = "", top_k: int = 2):
    """Keyword search is not implemented for the PGVector backend.

    Fix: the original stub took only ``self``, so any call with
    (query, top_k) raised TypeError. Defaults keep zero-arg calls working.
    Returns an empty list so callers can iterate the result safely.
    """
    return []
def searchbyContentInternal(self, query: str = "", top_k: int = 2):
    """Keyword search is not implemented for the PGVector backend.

    Fix: the base class invokes self.searchbyContentInternal(query, top_k),
    so the original zero-argument stub raised TypeError; returning [] (not
    None) keeps merge_and_deduplicate() from crashing on the result.
    """
    return []
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.pg_vector.add_documents(docs) ids = self.pg_vector.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]

View File

@ -94,6 +94,13 @@ class RelytKBService(KBService):
docs = self.relyt.similarity_search_with_score(query, top_k) docs = self.relyt.similarity_search_with_score(query, top_k)
return score_threshold_process(score_threshold, top_k, docs) return score_threshold_process(score_threshold, top_k, docs)
def searchbyContent(self, query: str = "", top_k: int = 2):
    """Keyword search is not implemented for the Relyt backend.

    Fix: the original stub took only ``self``, so any call with
    (query, top_k) raised TypeError. Defaults keep zero-arg calls working.
    Returns an empty list so callers can iterate the result safely.
    """
    return []
def searchbyContentInternal(self, query: str = "", top_k: int = 2):
    """Keyword search is not implemented for the Relyt backend.

    Fix: the base class invokes self.searchbyContentInternal(query, top_k),
    so the original zero-argument stub raised TypeError; returning [] (not
    None) keeps merge_and_deduplicate() from crashing on the result.
    """
    return []
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
print(docs) print(docs)
ids = self.relyt.add_documents(docs) ids = self.relyt.add_documents(docs)

View File

@ -79,6 +79,12 @@ class ZillizKBService(KBService):
docs = retriever.get_relevant_documents(query) docs = retriever.get_relevant_documents(query)
return docs return docs
def searchbyContent(self, query: str = "", top_k: int = 2):
    """Keyword search is not implemented for the Zilliz backend.

    Fix: the original stub took only ``self``, so any call with
    (query, top_k) raised TypeError. Defaults keep zero-arg calls working.
    Returns an empty list so callers can iterate the result safely.
    """
    return []
def searchbyContentInternal(self, query: str = "", top_k: int = 2):
    """Keyword search is not implemented for the Zilliz backend.

    Fix: the base class invokes self.searchbyContentInternal(query, top_k),
    so the original zero-argument stub raised TypeError; returning [] (not
    None) keeps merge_and_deduplicate() from crashing on the result.
    """
    return []
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs: for doc in docs:
for k, v in doc.metadata.items(): for k, v in doc.metadata.items():