Compare commits

...

2 Commits

Author SHA1 Message Date
weiweiw 7f57a4ede8 Add keyword search and ES vector store retrieval 2025-01-21 12:20:39 +08:00
weiweiw 08cb4b8750 Add keyword search and switch the vector store to ES 2025-01-20 15:30:05 +08:00
9 changed files with 150 additions and 10 deletions
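Together, the two commits add a keyword-search path next to the existing vector search and merge the two result sets before generation. A minimal sketch of the flow they introduce, using the names from the diff below (kb, query, top_k, score_threshold as in search_docs):

docs = kb.search_docs(query, top_k, score_threshold)   # vector similarity search
docs_key = kb.search_content_internal(query, 2)        # ES keyword search, top 2 hits
docs = merge_and_deduplicate(docs, docs_key)           # dedupe by page_content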

View File

@@ -22,7 +22,7 @@ from chatchat.server.utils import (wrap_done, get_ChatOpenAI, get_default_llm,
BaseResponse, get_prompt_template, build_logger,
check_embed_model, api_address
)
import time
logger = build_logger()
@@ -60,6 +60,8 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
return_direct: bool = Body(False, description="Return retrieval results directly without sending them to the LLM"),
request: Request = None,
):
logger.info(f"kb_chat: mode={mode}")
start_time = time.time()
if mode == "local_kb":
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb is None:
@@ -67,6 +69,8 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
try:
logger.info(f"***********************************knowledge_base_chat_iterator:,mode {mode}")
start_time1 = time.time()
nonlocal history, prompt_name, max_tokens
history = [History.from_data(h) for h in history]
@@ -74,8 +78,10 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
if mode == "local_kb":
kb = KBServiceFactory.get_service_by_name(kb_name)
ok, msg = kb.check_embed_model()
logger.info(f"***********************************knowledge_base_chat_iterator:,mode {mode}kb_name{kb_name}")
if not ok:
raise ValueError(msg)
# docs = search_docs(query=query, knowledge_base_name=kb_name, top_k=top_k, score_threshold=score_threshold)
docs = await run_in_threadpool(search_docs,
query=query,
knowledge_base_name=kb_name,
@@ -83,7 +89,13 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
score_threshold=score_threshold,
file_name="",
metadata={})
source_documents = format_reference(kb_name, docs, api_address(is_public=True))
logger.info(f"knowledge_base_chat_iterator: after format_reference: {docs}")
end_time1 = time.time()
execution_time1 = end_time1 - start_time1
logger.info(f"kb_chat retrieval finished, execution time: {execution_time1:.6f} seconds")
elif mode == "temp_kb":
ok, msg = check_embed_model()
if not ok:
@@ -139,6 +151,7 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
if max_tokens in [None, 0]:
max_tokens = Settings.model_settings.MAX_TOKENS
start_time1 = time.time()
llm = get_ChatOpenAI(
model_name=model,
temperature=temperature,
@@ -223,6 +236,12 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
return
if stream:
-return EventSourceResponse(knowledge_base_chat_iterator())
+event_source = EventSourceResponse(knowledge_base_chat_iterator())
+# Record the end time
+end_time = time.time()
+# Compute the execution time
+execution_time = end_time - start_time
+logger.info(f"final kb_chat Execution time: {execution_time:.6f} seconds")
+return event_source
else:
return await knowledge_base_chat_iterator().__anext__()
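Note that EventSourceResponse consumes the generator lazily, so the execution_time logged above measures request setup only, not the streamed answer. A sketch of timing the whole stream, assuming a hypothetical timed() wrapper that is not part of this change:

async def timed(gen, log):
    # Wrap an async generator and log wall time once the stream finishes
    start = time.time()
    try:
        async for chunk in gen:
            yield chunk
    finally:
        log.info(f"kb_chat full stream time: {time.time() - start:.6f} seconds")

# usage: EventSourceResponse(timed(knowledge_base_chat_iterator(), logger))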

View File

@@ -32,6 +32,7 @@ from chatchat.server.utils import (
get_default_embedding,
)
from chatchat.utils import build_logger
from typing import List, Dict, Tuple
logger = build_logger()
@@ -71,8 +72,11 @@ def search_docs(
if kb is not None:
if query:
docs = kb.search_docs(query, top_k, score_threshold)
logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}")
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
# logger.info(f"search_docs, query:{query},top_k:{top_k},len(docs):{len(docs)},docs:{docs}")
docs_key = kb.search_content_internal(query,2)
# logger.info(f"before merge_and_deduplicate docs_key:{docs_key}")
docs = merge_and_deduplicate(docs, docs_key)
logger.info(f"after merge_and_deduplicate docs:{docs}")
data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata)
@@ -81,6 +85,17 @@ def search_docs(
del d.metadata["vector"]
return [x.dict() for x in data]
def merge_and_deduplicate(list1: List[Document], list2: List[Document]) -> List[Document]:
    # Keep unique Documents in a dict keyed by page_content
    merged_dict = {doc.page_content: doc for doc in list1}
    # Add Documents from list2 whose content is not already present
    for doc in list2:
        if doc.page_content not in merged_dict:
            merged_dict[doc.page_content] = doc
    # Return the deduplicated list
    return list(merged_dict.values())
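The merge keeps whichever copy of a page_content it sees first, so vector hits in list1 win over keyword hits in list2. A quick sketch with made-up Documents:

a = [Document(page_content="alpha", metadata={"source": "vector"})]
b = [Document(page_content="alpha", metadata={"source": "keyword"}),
     Document(page_content="beta", metadata={"source": "keyword"})]
merged = merge_and_deduplicate(a, b)
# -> the "vector" copy of "alpha" plus "beta"; the keyword duplicate is dropped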
def list_files(knowledge_base_name: str) -> ListResponse:
if not validate_kb_name(knowledge_base_name):

View File

@@ -210,6 +210,12 @@ class KBService(ABC):
docs = self.do_search(query, top_k, score_threshold)
return docs
def search_content_internal(self,
                            query: str,
                            top_k: int,
                            ) -> List[Document]:
    docs = self.searchbyContentInternal(query, top_k)
    return docs
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
return []
@@ -319,6 +325,16 @@
"""
pass
@abstractmethod
def searchbyContentInternal(self,
                            query: str,
                            top_k: int,
                            ) -> List[Document]:
    """
    Keyword search over the knowledge base; each subclass implements its own logic.
    """
    pass
@abstractmethod
def do_add_doc(
self,

View File

@@ -16,7 +16,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings
from chatchat.utils import build_logger
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
logger = build_logger()
@@ -37,9 +37,12 @@ class ESKBService(KBService):
self.client_cert = kb_config.get("client_cert", None)
self.dims_length = kb_config.get("dims_length", None)
self.embeddings_model = get_Embeddings(self.embed_model)
logger.info(f"self.kb_path:{self.kb_path },self.index_name:{self.index_name}, self.scheme:{self.scheme},self.IP:{self.IP},"
f"self.PORT:{self.PORT},self.user:{self.user},self.password:{self.password},self.verify_certs:{self.verify_certs},"
f"self.client_cert:{self.client_cert},self.client_key:{self.client_key},self.dims_length:{self.dims_length}")
try:
connection_info = dict(
host=f"{self.scheme}://{self.IP}:{self.PORT}"
hosts=f"{self.scheme}://{self.IP}:{self.PORT}"
)
if self.user != "" and self.password != "":
connection_info.update(basic_auth=(self.user, self.password))
@@ -53,7 +56,9 @@
connection_info.update(client_key=self.client_key)
connection_info.update(client_cert=self.client_cert)
# ES Python client connection (connection only)
logger.info(f"connection_info:{connection_info}")
self.es_client_python = Elasticsearch(**connection_info)
# logger.info(f"after Elasticsearch connection_info:{connection_info}")
except ConnectionError:
logger.error("连接到 Elasticsearch 失败!")
raise ConnectionError
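Constructing the Elasticsearch client does not by itself verify that the cluster is reachable; network problems typically surface on the first request. A cheap liveness check right after construction (a sketch; ping() returns False rather than raising):

if not self.es_client_python.ping():
    logger.error("Failed to connect to Elasticsearch!")
    raise ConnectionError("Elasticsearch ping failed")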
@@ -143,6 +148,56 @@
docs = retriever.get_relevant_documents(query)
return docs
def searchbyContent(self, query: str, top_k: int = 2):
    if self.es_client_python.indices.exists(index=self.index_name):
        logger.info(f"ESKBService searchbyContent: index={self.index_name}, query={query}")
        # NOTE: "match" analyzes the query text, so the surrounding "*" characters
        # are stripped by the analyzer rather than treated as wildcards.
        tem_query = {
            "query": {"match": {
                "context": "*" + query + "*"
            }},
            "highlight": {"fields": {
                "context": {}
            }}
        }
        search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k)
        hits = search_results["hits"]["hits"]
        docs_and_scores = []
        for hit in hits:
            highlighted_contexts = ""
            if "highlight" in hit:
                highlighted_contexts = " ".join(hit["highlight"]["context"])
            docs_and_scores.append(DocumentWithVSId(
                page_content=highlighted_contexts,
                metadata=hit["_source"]["metadata"],
                id=hit["_id"],
            ))
        return docs_and_scores
    return []  # index missing: return an empty list instead of None
def searchbyContentInternal(self, query: str, top_k: int = 2):
    if self.es_client_python.indices.exists(index=self.index_name):
        logger.info(f"ESKBService searchbyContentInternal: index={self.index_name}, query={query}")
        tem_query = {
            "query": {"match": {
                "context": "*" + query + "*"
            }}
        }
        search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k)
        hits = search_results["hits"]["hits"]
        docs = [
            Document(
                page_content=hit["_source"]["context"],
                metadata=hit["_source"]["metadata"],
            )
            for hit in hits
        ]
        return docs
    return []  # index missing: return an empty list instead of None
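Because match runs the query through the field analyzer, the asterisks above do not act as wildcards; both methods perform an ordinary analyzed full-text match. If literal substring matching were the goal, the usual ES idiom is a wildcard query against a keyword field. A sketch, assuming the mapping exposes a context.keyword sub-field (not confirmed by this diff):

tem_query = {
    "query": {
        "wildcard": {
            "context.keyword": {"value": f"*{query}*"}
        }
    }
}

Leading wildcards scan many terms and can be slow on large indices, which makes the analyzed match a reasonable default here.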
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
results = []
for doc_id in ids:
@@ -179,10 +234,12 @@
},
"track_total_hits": True,
}
-# NOTE: set size explicitly; ES returns 10 hits by default, and with track_total_hits=True the search reports the real total in the index.
-size = self.es_client_python.search(body=query)["hits"]["total"]["value"]
-search_results = self.es_client_python.search(body=query, size=size)
-delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]]
+print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}")
+# NOTE: set size explicitly; ES returns only 10 hits by default.
+search_results = self.es_client_python.search(index=self.index_name, body=query, size=200)
+delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]]
+size = len(delete_list)
+# print(f"***do_delete_doc: deleted size: {size}, {delete_list}")
if len(delete_list) == 0:
return None
else:
@@ -226,6 +283,8 @@
{"id": hit["_id"], "metadata": hit["_source"]["metadata"]}
for hit in search_results["hits"]["hits"]
]
# size = len(info_docs)
# print(f"do_add_doc: number of recalled documents: {size}")
return info_docs
def do_clear_vs(self):

View File

@@ -78,6 +78,12 @@ class FaissKBService(KBService):
docs = retriever.get_relevant_documents(query)
return docs
def searchbyContent(self, query: str, top_k: int = 2):
    # Keyword search is not implemented for FAISS; return an empty result
    return []

def searchbyContentInternal(self, query: str, top_k: int = 2):
    return []
def do_add_doc(
self,
docs: List[Document],

View File

@@ -88,6 +88,12 @@ class MilvusKBService(KBService):
docs = retriever.get_relevant_documents(query)
return docs
def searchbyContent(self, query: str, top_k: int = 2):
    # Keyword search is not implemented for Milvus; return an empty result
    return []

def searchbyContentInternal(self, query: str, top_k: int = 2):
    return []
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs:
for k, v in doc.metadata.items():

View File

@@ -84,6 +84,12 @@ class PGKBService(KBService):
docs = retriever.get_relevant_documents(query)
return docs
def searchbyContent(self, query: str, top_k: int = 2):
    # Keyword search is not implemented for PGVector; return an empty result
    return []

def searchbyContentInternal(self, query: str, top_k: int = 2):
    return []
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.pg_vector.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]

View File

@@ -94,6 +94,13 @@ class RelytKBService(KBService):
docs = self.relyt.similarity_search_with_score(query, top_k)
return score_threshold_process(score_threshold, top_k, docs)
def searchbyContent(self, query: str, top_k: int = 2):
    # Keyword search is not implemented for Relyt; return an empty result
    return []

def searchbyContentInternal(self, query: str, top_k: int = 2):
    return []
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
print(docs)
ids = self.relyt.add_documents(docs)

View File

@@ -79,6 +79,12 @@ class ZillizKBService(KBService):
docs = retriever.get_relevant_documents(query)
return docs
def searchbyContent(self, query: str, top_k: int = 2):
    # Keyword search is not implemented for Zilliz; return an empty result
    return []

def searchbyContentInternal(self, query: str, top_k: int = 2):
    return []
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs:
for k, v in doc.metadata.items():