增加关键词检索和es向量库检索
This commit is contained in:
parent
08cb4b8750
commit
7f57a4ede8
|
|
@ -22,7 +22,7 @@ from chatchat.server.utils import (wrap_done, get_ChatOpenAI, get_default_llm,
|
|||
BaseResponse, get_prompt_template, build_logger,
|
||||
check_embed_model, api_address
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
logger = build_logger()
|
||||
|
||||
|
|
@ -60,6 +60,8 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
|
|||
return_direct: bool = Body(False, description="直接返回检索结果,不送入 LLM"),
|
||||
request: Request = None,
|
||||
):
|
||||
logger.info(f"kb_chat:,mode {mode}")
|
||||
start_time = time.time()
|
||||
if mode == "local_kb":
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
if kb is None:
|
||||
|
|
@ -67,6 +69,8 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
|
|||
|
||||
async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
|
||||
try:
|
||||
logger.info(f"***********************************knowledge_base_chat_iterator:,mode {mode}")
|
||||
start_time1 = time.time()
|
||||
nonlocal history, prompt_name, max_tokens
|
||||
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
|
@ -74,8 +78,10 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
|
|||
if mode == "local_kb":
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
ok, msg = kb.check_embed_model()
|
||||
logger.info(f"***********************************knowledge_base_chat_iterator:,mode {mode},kb_name:{kb_name}")
|
||||
if not ok:
|
||||
raise ValueError(msg)
|
||||
# docs = search_docs( query = query,knowledge_base_name = kb_name,top_k = top_k, score_threshold = score_threshold,)
|
||||
docs = await run_in_threadpool(search_docs,
|
||||
query=query,
|
||||
knowledge_base_name=kb_name,
|
||||
|
|
@ -83,7 +89,13 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
|
|||
score_threshold=score_threshold,
|
||||
file_name="",
|
||||
metadata={})
|
||||
|
||||
source_documents = format_reference(kb_name, docs, api_address(is_public=True))
|
||||
logger.info(
|
||||
f"***********************************knowledge_base_chat_iterator:,after format_reference:{docs}")
|
||||
end_time1 = time.time()
|
||||
execution_time1 = end_time1 - start_time1
|
||||
logger.info(f"kb_chat Execution time检索完成: {execution_time1:.6f} seconds")
|
||||
elif mode == "temp_kb":
|
||||
ok, msg = check_embed_model()
|
||||
if not ok:
|
||||
|
|
@ -139,6 +151,7 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
|
|||
if max_tokens in [None, 0]:
|
||||
max_tokens = Settings.model_settings.MAX_TOKENS
|
||||
|
||||
start_time1 = time.time()
|
||||
llm = get_ChatOpenAI(
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
|
|
@ -223,6 +236,12 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
|
|||
return
|
||||
|
||||
if stream:
|
||||
return EventSourceResponse(knowledge_base_chat_iterator())
|
||||
eventSource = EventSourceResponse(knowledge_base_chat_iterator())
|
||||
# 记录结束时间
|
||||
end_time = time.time()
|
||||
# 计算执行时间
|
||||
execution_time = end_time - start_time
|
||||
logger.info(f"final kb_chat Execution time: {execution_time:.6f} seconds")
|
||||
return eventSource
|
||||
else:
|
||||
return await knowledge_base_chat_iterator().__anext__()
|
||||
|
|
|
|||
|
|
@ -72,10 +72,11 @@ def search_docs(
|
|||
if kb is not None:
|
||||
if query:
|
||||
docs = kb.search_docs(query, top_k, score_threshold)
|
||||
logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}")
|
||||
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
||||
# logger.info(f"search_docs, query:{query},top_k:{top_k},len(docs):{len(docs)},docs:{docs}")
|
||||
docs_key = kb.search_content_internal(query,2)
|
||||
# logger.info(f"before merge_and_deduplicate docs_key:{docs_key}")
|
||||
docs = merge_and_deduplicate(docs, docs_key)
|
||||
logger.info(f"after merge_and_deduplicate docs:{docs}")
|
||||
data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
|
||||
elif file_name or metadata:
|
||||
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
||||
|
|
@ -84,26 +85,16 @@ def search_docs(
|
|||
del d.metadata["vector"]
|
||||
return [x.dict() for x in data]
|
||||
|
||||
def merge_and_deduplicate(list1: List[Tuple[Document, float]], list2: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
|
||||
# 使用字典来存储 page_content 和对应的元组 (Document, float)
|
||||
merged_dict = {}
|
||||
# 遍历第一个列表
|
||||
for item in list1:
|
||||
document, value = item
|
||||
page_content = document.page_content
|
||||
# 如果 page_content 不在字典中,将其加入字典
|
||||
if page_content not in merged_dict:
|
||||
merged_dict[page_content] = item
|
||||
def merge_and_deduplicate(list1: List[Document], list2: List[Document]) -> List[Document]:
|
||||
# 使用字典存储唯一的 Document
|
||||
merged_dict = {doc.page_content: doc for doc in list1}
|
||||
|
||||
# 遍历第二个列表
|
||||
for item in list2:
|
||||
document, value = item
|
||||
page_content = document.page_content
|
||||
# 如果 page_content 不在字典中,将其加入字典
|
||||
if page_content not in merged_dict:
|
||||
merged_dict[page_content] = item
|
||||
# 遍历 list2,将新的 Document 添加到字典
|
||||
for doc in list2:
|
||||
if doc.page_content not in merged_dict:
|
||||
merged_dict[doc.page_content] = doc
|
||||
|
||||
# 将字典的值转换为列表并返回
|
||||
# 返回去重后的列表
|
||||
return list(merged_dict.values())
|
||||
|
||||
def list_files(knowledge_base_name: str) -> ListResponse:
|
||||
|
|
|
|||
|
|
@ -213,7 +213,7 @@ class KBService(ABC):
|
|||
def search_content_internal(self,
|
||||
query: str,
|
||||
top_k: int,
|
||||
)->List[Tuple[Document, float]]:
|
||||
)->List[Document]:
|
||||
docs = self.searchbyContentInternal(query,top_k)
|
||||
return docs
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||
|
|
|
|||
|
|
@ -37,9 +37,12 @@ class ESKBService(KBService):
|
|||
self.client_cert = kb_config.get("client_cert", None)
|
||||
self.dims_length = kb_config.get("dims_length", None)
|
||||
self.embeddings_model = get_Embeddings(self.embed_model)
|
||||
logger.info(f"self.kb_path:{self.kb_path },self.index_name:{self.index_name}, self.scheme:{self.scheme},self.IP:{self.IP},"
|
||||
f"self.PORT:{self.PORT},self.user:{self.user},self.password:{self.password},self.verify_certs:{self.verify_certs},"
|
||||
f"self.client_cert:{self.client_cert},self.client_key:{self.client_key},self.dims_length:{self.dims_length}")
|
||||
try:
|
||||
connection_info = dict(
|
||||
host=f"{self.scheme}://{self.IP}:{self.PORT}"
|
||||
hosts=f"{self.scheme}://{self.IP}:{self.PORT}"
|
||||
)
|
||||
if self.user != "" and self.password != "":
|
||||
connection_info.update(basic_auth=(self.user, self.password))
|
||||
|
|
@ -53,7 +56,9 @@ class ESKBService(KBService):
|
|||
connection_info.update(client_key=self.client_key)
|
||||
connection_info.update(client_cert=self.client_cert)
|
||||
# ES python客户端连接(仅连接)
|
||||
logger.info(f"connection_info:{connection_info}")
|
||||
self.es_client_python = Elasticsearch(**connection_info)
|
||||
# logger.info(f"after Elasticsearch connection_info:{connection_info}")
|
||||
except ConnectionError:
|
||||
logger.error("连接到 Elasticsearch 失败!")
|
||||
raise ConnectionError
|
||||
|
|
@ -181,15 +186,17 @@ class ESKBService(KBService):
|
|||
search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k)
|
||||
hits = [hit for hit in search_results["hits"]["hits"]]
|
||||
docs_and_scores = [
|
||||
(
|
||||
# (
|
||||
Document(
|
||||
page_content=hit["_source"]["context"],
|
||||
metadata=hit["_source"]["metadata"],
|
||||
),
|
||||
1.3,
|
||||
)
|
||||
)
|
||||
# ,
|
||||
# 1.3,
|
||||
# )
|
||||
for hit in hits
|
||||
]
|
||||
# logger.info(f"docs_and_scores:{docs_and_scores}")
|
||||
return docs_and_scores
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||
results = []
|
||||
|
|
@ -227,7 +234,7 @@ class ESKBService(KBService):
|
|||
},
|
||||
"track_total_hits": True,
|
||||
}
|
||||
print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}, base_file_name:{base_file_name}")
|
||||
print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}")
|
||||
# 注意设置size,默认返回10个。
|
||||
search_results = self.es_client_python.search(index=self.index_name, body=query,size=200)
|
||||
delete_list = [hit["_id"] for hit in search_results['hits']['hits']]
|
||||
|
|
|
|||
Loading…
Reference in New Issue