diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py index 2ff0a0a..5949afe 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py @@ -32,6 +32,7 @@ from chatchat.server.utils import ( get_default_embedding, ) from chatchat.utils import build_logger +from typing import List, Dict,Tuple logger = build_logger() @@ -73,6 +74,8 @@ def search_docs( docs = kb.search_docs(query, top_k, score_threshold) logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}") # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] + docs_key = kb.search_content_internal(query,2) + docs = merge_and_deduplicate(docs, docs_key) data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs] elif file_name or metadata: data = kb.list_docs(file_name=file_name, metadata=metadata) @@ -81,6 +84,27 @@ def search_docs( del d.metadata["vector"] return [x.dict() for x in data] +def merge_and_deduplicate(list1: List[Tuple[Document, float]], list2: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]: + # 使用字典来存储 page_content 和对应的元组 (Document, float) + merged_dict = {} + # 遍历第一个列表 + for item in list1: + document, value = item + page_content = document.page_content + # 如果 page_content 不在字典中,将其加入字典 + if page_content not in merged_dict: + merged_dict[page_content] = item + + # 遍历第二个列表 + for item in list2: + document, value = item + page_content = document.page_content + # 如果 page_content 不在字典中,将其加入字典 + if page_content not in merged_dict: + merged_dict[page_content] = item + + # 将字典的值转换为列表并返回 + return list(merged_dict.values()) def list_files(knowledge_base_name: str) -> ListResponse: if not validate_kb_name(knowledge_base_name): diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py index 7d5c4c4..5be3efb 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py @@ -210,6 +210,12 @@ class KBService(ABC): docs = self.do_search(query, top_k, score_threshold) return docs + def search_content_internal(self, + query: str, + top_k: int, + )->List[Tuple[Document, float]]: + docs = self.searchbyContentInternal(query,top_k) + return docs def get_doc_by_ids(self, ids: List[str]) -> List[Document]: return [] @@ -319,6 +325,16 @@ class KBService(ABC): """ pass + @abstractmethod + def searchbyContentInternal(self, + query: str, + top_k: int, + )->List[Tuple[Document, float]]: + """ + 搜索知识库子类实自己逻辑 + """ + pass + @abstractmethod def do_add_doc( self, diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py index b50c8a0..0d918cb 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py @@ -16,7 +16,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.utils import get_Embeddings from chatchat.utils import build_logger - +from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId logger = build_logger() @@ -143,6 +143,54 @@ class ESKBService(KBService): docs = retriever.get_relevant_documents(query) return docs + def searchbyContent(self, query:str, top_k: int = 2): + if self.es_client_python.indices.exists(index=self.index_name): + logger.info(f"******ESKBService searchByContent {self.index_name},query:{query}") + tem_query = { + "query": {"match": { + "context": "*" + query + "*" + }}, + "highlight":{"fields":{ + "context":{} + }} + } + search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k) + hits = [hit for hit in search_results["hits"]["hits"]] + + docs_and_scores = [] + for hit in hits: + highlighted_contexts = "" + if 'highlight' in hit: + highlighted_contexts = " ".join(hit['highlight']['context']) + #print(f"******searchByContent highlighted_contexts:{highlighted_contexts}") + docs_and_scores.append(DocumentWithVSId( + page_content=highlighted_contexts, + metadata=hit["_source"]["metadata"], + id = hit["_id"], + )) + return docs_and_scores + + def searchbyContentInternal(self, query:str, top_k: int = 2): + if self.es_client_python.indices.exists(index=self.index_name): + logger.info(f"******ESKBService searchbyContentInternal {self.index_name},query:{query}") + tem_query = { + "query": {"match": { + "context": "*" + query + "*" + }} + } + search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k) + hits = [hit for hit in search_results["hits"]["hits"]] + docs_and_scores = [ + ( + Document( + page_content=hit["_source"]["context"], + metadata=hit["_source"]["metadata"], + ), + 1.3, + ) + for hit in hits + ] + return docs_and_scores def get_doc_by_ids(self, ids: List[str]) -> List[Document]: results = [] for doc_id in ids: @@ -179,10 +227,12 @@ class ESKBService(KBService): }, "track_total_hits": True, } - # 注意设置size,默认返回10个,es检索设置track_total_hits为True返回数据库中真实的size。 - size = self.es_client_python.search(body=query)["hits"]["total"]["value"] - search_results = self.es_client_python.search(body=query, size=size) - delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]] + print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}, base_file_name:{base_file_name}") + # 注意设置size,默认返回10个。 + search_results = self.es_client_python.search(index=self.index_name, body=query,size=200) + delete_list = [hit["_id"] for hit in search_results['hits']['hits']] + size = len(delete_list) + #print(f"***do_delete_doc: 删除的size:{size}, {delete_list}") if len(delete_list) == 0: return None else: @@ -226,6 +276,8 @@ class ESKBService(KBService): {"id": hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"] ] + #size = len(info_docs) + #print(f"do_add_doc 召回元素个数:{size}") return info_docs def do_clear_vs(self): diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py index 32dae07..e31151a 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py @@ -78,6 +78,12 @@ class FaissKBService(KBService): docs = retriever.get_relevant_documents(query) return docs + def searchbyContent(self): + pass + + def searchbyContentInternal(self): + pass + def do_add_doc( self, docs: List[Document], diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py index ef06c24..4041519 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py @@ -88,6 +88,12 @@ class MilvusKBService(KBService): docs = retriever.get_relevant_documents(query) return docs + def searchbyContent(self): + pass + + def searchbyContentInternal(self): + pass + def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: for doc in docs: for k, v in doc.metadata.items(): diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py index 2f40f13..5e10b6e 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py @@ -84,6 +84,12 @@ class PGKBService(KBService): docs = retriever.get_relevant_documents(query) return docs + def searchbyContent(self): + pass + + def searchbyContentInternal(self): + pass + def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: ids = self.pg_vector.add_documents(docs) doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/relyt_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/relyt_kb_service.py index 0b85102..eccfe31 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/relyt_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/relyt_kb_service.py @@ -94,6 +94,13 @@ class RelytKBService(KBService): docs = self.relyt.similarity_search_with_score(query, top_k) return score_threshold_process(score_threshold, top_k, docs) + + def searchbyContent(self): + pass + + def searchbyContentInternal(self): + pass + def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: print(docs) ids = self.relyt.add_documents(docs) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py index bfa3233..4af3dff 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py @@ -79,6 +79,12 @@ class ZillizKBService(KBService): docs = retriever.get_relevant_documents(query) return docs + def searchbyContent(self): + pass + + def searchbyContentInternal(self): + pass + def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: for doc in docs: for k, v in doc.metadata.items():