Add keyword search and ES vector-store retrieval

weiweiw 2025-01-21 12:20:39 +08:00
parent 08cb4b8750
commit 7f57a4ede8
4 changed files with 46 additions and 29 deletions

View File

@@ -22,7 +22,7 @@ from chatchat.server.utils import (wrap_done, get_ChatOpenAI, get_default_llm,
BaseResponse, get_prompt_template, build_logger,
check_embed_model, api_address
)
import time
logger = build_logger()
@@ -60,6 +60,8 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
return_direct: bool = Body(False, description="Return retrieval results directly without sending them to the LLM"),
request: Request = None,
):
logger.info(f"kb_chat:,mode {mode}")
start_time = time.time()
if mode == "local_kb":
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb is None:
@@ -67,6 +69,8 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
try:
logger.info(f"***********************************knowledge_base_chat_iterator:,mode {mode}")
start_time1 = time.time()
nonlocal history, prompt_name, max_tokens
history = [History.from_data(h) for h in history]
@@ -74,8 +78,10 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
if mode == "local_kb":
kb = KBServiceFactory.get_service_by_name(kb_name)
ok, msg = kb.check_embed_model()
logger.info(f"***********************************knowledge_base_chat_iterator:,mode {mode}kb_name{kb_name}")
if not ok:
raise ValueError(msg)
# Run the blocking search in a worker thread so it does not stall the event loop:
docs = await run_in_threadpool(search_docs,
query=query,
knowledge_base_name=kb_name,
@@ -83,7 +89,13 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
score_threshold=score_threshold,
file_name="",
metadata={})
source_documents = format_reference(kb_name, docs, api_address(is_public=True))
logger.info(f"knowledge_base_chat_iterator: docs after retrieval: {docs}")
end_time1 = time.time()
execution_time1 = end_time1 - start_time1
logger.info(f"kb_chat retrieval finished in {execution_time1:.6f} seconds")
elif mode == "temp_kb":
ok, msg = check_embed_model()
if not ok:
@@ -139,6 +151,7 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
if max_tokens in [None, 0]:
max_tokens = Settings.model_settings.MAX_TOKENS
start_time1 = time.time()
llm = get_ChatOpenAI(
model_name=model,
temperature=temperature,
@@ -223,6 +236,12 @@ async def kb_chat(query: str = Body(..., description="User input", examples=["
return
if stream:
return EventSourceResponse(knowledge_base_chat_iterator())
event_source = EventSourceResponse(knowledge_base_chat_iterator())
# Record the end time
end_time = time.time()
# Compute the elapsed time; note this covers handler setup only, since
# the stream is consumed after the response object is returned
execution_time = end_time - start_time
logger.info(f"final kb_chat Execution time: {execution_time:.6f} seconds")
return event_source
else:
return await knowledge_base_chat_iterator().__anext__()
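The timing code added in this file repeats the same start/stop/log pattern three times. A minimal consolidation sketch, assuming only the standard library and the module's existing logger (log_elapsed is a hypothetical helper, not part of this commit):

    import time
    from contextlib import contextmanager

    @contextmanager
    def log_elapsed(label: str):
        # perf_counter is monotonic and intended for interval timing,
        # unlike time.time(), which can jump if the wall clock is adjusted
        start = time.perf_counter()
        try:
            yield
        finally:
            logger.info(f"{label}: {time.perf_counter() - start:.6f} seconds")

A call site would then read: with log_elapsed("kb_chat retrieval"): ... . In the streaming branch, any end time taken before returning the EventSourceResponse measures setup only; the generator body runs later, while the client consumes the stream.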

View File

@@ -72,10 +72,11 @@ def search_docs(
if kb is not None:
if query:
docs = kb.search_docs(query, top_k, score_threshold)
logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}")
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
# logger.info(f"search_docs, query:{query},top_k:{top_k},len(docs):{len(docs)},docs:{docs}")
docs_key = kb.search_content_internal(query,2)
# logger.info(f"before merge_and_deduplicate docs_key:{docs_key}")
docs = merge_and_deduplicate(docs, docs_key)
logger.info(f"after merge_and_deduplicate docs:{docs}")
data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata)
@@ -84,26 +85,16 @@ def search_docs(
del d.metadata["vector"]
return [x.dict() for x in data]
def merge_and_deduplicate(list1: List[Tuple[Document, float]], list2: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
# Map each page_content to its (Document, score) tuple
merged_dict = {}
# Walk the first list
for item in list1:
document, value = item
page_content = document.page_content
# Add the tuple if its page_content has not been seen yet
if page_content not in merged_dict:
merged_dict[page_content] = item
def merge_and_deduplicate(list1: List[Document], list2: List[Document]) -> List[Document]:
# Keep one Document per unique page_content
merged_dict = {doc.page_content: doc for doc in list1}
# Walk the second list
for item in list2:
document, value = item
page_content = document.page_content
# Add the tuple if its page_content has not been seen yet
if page_content not in merged_dict:
merged_dict[page_content] = item
# Walk list2, adding Documents whose page_content is new
for doc in list2:
if doc.page_content not in merged_dict:
merged_dict[doc.page_content] = doc
# Convert the dict values back to a list
# Return the deduplicated list
return list(merged_dict.values())
def list_files(knowledge_base_name: str) -> ListResponse:
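A quick usage sketch of the new Document-based helper; the Document import path and the sample contents are assumptions for illustration:

    from langchain_core.documents import Document  # assumed import path

    vector_hits = [Document(page_content="ES supports kNN vector search.", metadata={"id": "v1"})]
    keyword_hits = [
        Document(page_content="ES supports kNN vector search.", metadata={"id": "k1"}),  # duplicate text
        Document(page_content="ES scores keyword matches with BM25.", metadata={"id": "k2"}),
    ]

    merged = merge_and_deduplicate(vector_hits, keyword_hits)
    assert len(merged) == 2  # the duplicate page_content from keyword_hits is dropped

Ties on identical page_content resolve in favor of list1, so vector-search results keep their metadata when a keyword hit duplicates them.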

View File

@@ -213,7 +213,7 @@ class KBService(ABC):
def search_content_internal(self,
query: str,
top_k: int,
)->List[Tuple[Document, float]]:
)->List[Document]:
docs = self.searchbyContentInternal(query,top_k)
return docs
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:

View File

@@ -37,9 +37,12 @@ class ESKBService(KBService):
self.client_cert = kb_config.get("client_cert", None)
self.dims_length = kb_config.get("dims_length", None)
self.embeddings_model = get_Embeddings(self.embed_model)
logger.info(f"self.kb_path:{self.kb_path },self.index_name:{self.index_name}, self.scheme:{self.scheme},self.IP:{self.IP},"
f"self.PORT:{self.PORT},self.user:{self.user},self.password:{self.password},self.verify_certs:{self.verify_certs},"
f"self.client_cert:{self.client_cert},self.client_key:{self.client_key},self.dims_length:{self.dims_length}")
try:
connection_info = dict(
host=f"{self.scheme}://{self.IP}:{self.PORT}"
hosts=f"{self.scheme}://{self.IP}:{self.PORT}"
)
if self.user != "" and self.password != "":
connection_info.update(basic_auth=(self.user, self.password))
@@ -53,7 +56,9 @@ class ESKBService(KBService):
connection_info.update(client_key=self.client_key)
connection_info.update(client_cert=self.client_cert)
# Create the ES Python client (connection setup only)
logger.info(f"connection_info: { {k: v for k, v in connection_info.items() if k != 'basic_auth'} }")  # credentials kept out of the log
self.es_client_python = Elasticsearch(**connection_info)
except ConnectionError:
logger.error("Failed to connect to Elasticsearch!")
raise
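Worth noting: the Elasticsearch Python client is lazy, so constructing Elasticsearch(**connection_info) performs no I/O and the except branch above rarely fires at init time. A sketch of an explicit connectivity check; ping() is a real client method, but placing it here is an assumption about the desired behavior:

    if not self.es_client_python.ping():
        # ping() returns False instead of raising when the cluster is unreachable
        logger.error("Failed to connect to Elasticsearch!")
        raise ConnectionError("Elasticsearch ping failed")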
@@ -181,15 +186,17 @@ class ESKBService(KBService):
search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k)
hits = search_results["hits"]["hits"]
docs = [
Document(
page_content=hit["_source"]["context"],
metadata=hit["_source"]["metadata"],
)
for hit in hits
]
return docs
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
results = []
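tem_query is built above this hunk and not shown in the diff. A sketch of the kind of full-text query the keyword path presumably issues, matching the "context" field read back from _source above; the exact query shape is an assumption:

    tem_query = {
        "query": {
            "match": {
                "context": query  # BM25-scored full-text match on the indexed text field
            }
        }
    }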
@@ -227,7 +234,7 @@ class ESKBService(KBService):
},
"track_total_hits": True,
}
print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}, base_file_name:{base_file_name}")
print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}")
# Note: set size explicitly; an ES search returns only 10 hits by default
search_results = self.es_client_python.search(index=self.index_name, body=query, size=200)
delete_list = [hit["_id"] for hit in search_results['hits']['hits']]
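The delete path above fetches at most 200 hits per search and then deletes the collected IDs. A sketch of the server-side alternative, delete_by_query (a real client API, shown with the 7.x-style body parameter); the term field is an assumption, since the query body is truncated in this hunk:

    self.es_client_python.delete_by_query(
        index=self.index_name,
        body={"query": {"term": {"metadata.source.keyword": kb_file.filepath}}},
        refresh=True,  # make the deletions visible to the next search immediately
    )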