Add keyword search and ES vector-store retrieval
parent 08cb4b8750
commit 7f57a4ede8
@@ -22,7 +22,7 @@ from chatchat.server.utils import (wrap_done, get_ChatOpenAI, get_default_llm,
                                    BaseResponse, get_prompt_template, build_logger,
                                    check_embed_model, api_address
                                    )
+import time

 logger = build_logger()

@@ -60,6 +60,8 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
                   return_direct: bool = Body(False, description="直接返回检索结果,不送入 LLM"),
                   request: Request = None,
                   ):
+    logger.info(f"kb_chat:,mode {mode}")
+    start_time = time.time()
     if mode == "local_kb":
         kb = KBServiceFactory.get_service_by_name(kb_name)
         if kb is None:
@@ -67,6 +69,8 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["

     async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
         try:
+            logger.info(f"***********************************knowledge_base_chat_iterator:,mode {mode}")
+            start_time1 = time.time()
             nonlocal history, prompt_name, max_tokens

             history = [History.from_data(h) for h in history]
@@ -74,8 +78,10 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
             if mode == "local_kb":
                 kb = KBServiceFactory.get_service_by_name(kb_name)
                 ok, msg = kb.check_embed_model()
+                logger.info(f"***********************************knowledge_base_chat_iterator:,mode {mode},kb_name:{kb_name}")
                 if not ok:
                     raise ValueError(msg)
+                # docs = search_docs( query = query,knowledge_base_name = kb_name,top_k = top_k, score_threshold = score_threshold,)
                 docs = await run_in_threadpool(search_docs,
                                                query=query,
                                                knowledge_base_name=kb_name,
@@ -83,7 +89,13 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
                                                score_threshold=score_threshold,
                                                file_name="",
                                                metadata={})

                 source_documents = format_reference(kb_name, docs, api_address(is_public=True))
+                logger.info(
+                    f"***********************************knowledge_base_chat_iterator:,after format_reference:{docs}")
+                end_time1 = time.time()
+                execution_time1 = end_time1 - start_time1
+                logger.info(f"kb_chat Execution time检索完成: {execution_time1:.6f} seconds")
             elif mode == "temp_kb":
                 ok, msg = check_embed_model()
                 if not ok:
@@ -139,6 +151,7 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
             if max_tokens in [None, 0]:
                 max_tokens = Settings.model_settings.MAX_TOKENS

+            start_time1 = time.time()
             llm = get_ChatOpenAI(
                 model_name=model,
                 temperature=temperature,
@@ -223,6 +236,12 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
             return

     if stream:
-        return EventSourceResponse(knowledge_base_chat_iterator())
+        eventSource = EventSourceResponse(knowledge_base_chat_iterator())
+        # record the end time
+        end_time = time.time()
+        # compute the elapsed time
+        execution_time = end_time - start_time
+        logger.info(f"final kb_chat Execution time: {execution_time:.6f} seconds")
+        return eventSource
     else:
         return await knowledge_base_chat_iterator().__anext__()
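A note on the timing added in this hunk: EventSourceResponse only wraps the async generator and returns before any tokens are produced, so execution_time measures request setup, not the streamed answer. A sketch of how the full stream could be timed instead, reusing this file's module-level logger; the timed_stream wrapper is hypothetical and not part of the commit:

    import time
    from typing import AsyncIterable

    async def timed_stream(source: AsyncIterable[str]) -> AsyncIterable[str]:
        # Hypothetical wrapper: times the whole SSE stream, not just setup.
        start = time.time()
        try:
            async for chunk in source:
                yield chunk
        finally:
            # Runs once the last chunk is sent or the client disconnects.
            logger.info(f"kb_chat full stream time: {time.time() - start:.6f} seconds")

    # usage: return EventSourceResponse(timed_stream(knowledge_base_chat_iterator()))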
@@ -72,10 +72,11 @@ def search_docs(
     if kb is not None:
         if query:
             docs = kb.search_docs(query, top_k, score_threshold)
-            logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}")
-            # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
+            # logger.info(f"search_docs, query:{query},top_k:{top_k},len(docs):{len(docs)},docs:{docs}")
             docs_key = kb.search_content_internal(query,2)
+            # logger.info(f"before merge_and_deduplicate docs_key:{docs_key}")
             docs = merge_and_deduplicate(docs, docs_key)
+            logger.info(f"after merge_and_deduplicate docs:{docs}")
             data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
         elif file_name or metadata:
             data = kb.list_docs(file_name=file_name, metadata=metadata)
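Taken together, this hunk turns search_docs into a small hybrid retriever: an embedding-similarity pass first, then a keyword pass over the same knowledge base (hard-coded to the top 2 hits), merged so that vector hits win on duplicate content. A minimal sketch of the flow under those assumptions; vector_search and keyword_search are hypothetical stand-ins for kb.search_docs and kb.search_content_internal, and merge_and_deduplicate is the helper defined in the next hunk:

    from typing import Callable, List
    from langchain_core.documents import Document  # assumed Document type

    def hybrid_search(query: str, top_k: int,
                      vector_search: Callable[[str, int], List[Document]],
                      keyword_search: Callable[[str, int], List[Document]]) -> List[Document]:
        docs = vector_search(query, top_k)    # embedding-similarity hits
        docs_key = keyword_search(query, 2)   # top-2 keyword (full-text) hits
        # keyword hits survive only if their page_content is new
        return merge_and_deduplicate(docs, docs_key)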
@@ -84,26 +85,16 @@ def search_docs(
             del d.metadata["vector"]
     return [x.dict() for x in data]

-def merge_and_deduplicate(list1: List[Tuple[Document, float]], list2: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
-    # use a dict keyed by page_content to store the (Document, float) tuples
-    merged_dict = {}
-    # walk the first list
-    for item in list1:
-        document, value = item
-        page_content = document.page_content
-        # add the tuple if its page_content is not in the dict yet
-        if page_content not in merged_dict:
-            merged_dict[page_content] = item
-
-    # walk the second list
-    for item in list2:
-        document, value = item
-        page_content = document.page_content
-        # add the tuple if its page_content is not in the dict yet
-        if page_content not in merged_dict:
-            merged_dict[page_content] = item
-
-    # convert the dict values back to a list and return it
+def merge_and_deduplicate(list1: List[Document], list2: List[Document]) -> List[Document]:
+    # use a dict to keep one Document per unique page_content
+    merged_dict = {doc.page_content: doc for doc in list1}
+
+    # walk list2, adding Documents whose content has not been seen yet
+    for doc in list2:
+        if doc.page_content not in merged_dict:
+            merged_dict[doc.page_content] = doc
+
+    # return the deduplicated list
     return list(merged_dict.values())

 def list_files(knowledge_base_name: str) -> ListResponse:
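The rewrite drops the score element and dedupes by exact page_content, with list1 taking precedence; dict insertion order (guaranteed since Python 3.7) keeps vector hits ahead of keyword hits. A quick usage sketch, assuming langchain's Document class:

    from langchain_core.documents import Document

    a = [Document(page_content="foo"), Document(page_content="bar")]
    b = [Document(page_content="bar"), Document(page_content="baz")]

    merged = merge_and_deduplicate(a, b)
    print([d.page_content for d in merged])  # ['foo', 'bar', 'baz']; the "bar" from a wins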
@@ -213,7 +213,7 @@ class KBService(ABC):
     def search_content_internal(self,
                                 query: str,
                                 top_k: int,
-                                )->List[Tuple[Document, float]]:
+                                )->List[Document]:
         docs = self.searchbyContentInternal(query,top_k)
         return docs

     def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
@@ -37,9 +37,12 @@ class ESKBService(KBService):
         self.client_cert = kb_config.get("client_cert", None)
         self.dims_length = kb_config.get("dims_length", None)
         self.embeddings_model = get_Embeddings(self.embed_model)
+        logger.info(f"self.kb_path:{self.kb_path },self.index_name:{self.index_name}, self.scheme:{self.scheme},self.IP:{self.IP},"
+                    f"self.PORT:{self.PORT},self.user:{self.user},self.password:{self.password},self.verify_certs:{self.verify_certs},"
+                    f"self.client_cert:{self.client_cert},self.client_key:{self.client_key},self.dims_length:{self.dims_length}")
         try:
             connection_info = dict(
-                host=f"{self.scheme}://{self.IP}:{self.PORT}"
+                hosts=f"{self.scheme}://{self.IP}:{self.PORT}"
             )
             if self.user != "" and self.password != "":
                 connection_info.update(basic_auth=(self.user, self.password))
|
|
@ -53,7 +56,9 @@ class ESKBService(KBService):
|
||||||
connection_info.update(client_key=self.client_key)
|
connection_info.update(client_key=self.client_key)
|
||||||
connection_info.update(client_cert=self.client_cert)
|
connection_info.update(client_cert=self.client_cert)
|
||||||
# ES python客户端连接(仅连接)
|
# ES python客户端连接(仅连接)
|
||||||
|
logger.info(f"connection_info:{connection_info}")
|
||||||
self.es_client_python = Elasticsearch(**connection_info)
|
self.es_client_python = Elasticsearch(**connection_info)
|
||||||
|
# logger.info(f"after Elasticsearch connection_info:{connection_info}")
|
||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
logger.error("连接到 Elasticsearch 失败!")
|
logger.error("连接到 Elasticsearch 失败!")
|
||||||
raise ConnectionError
|
raise ConnectionError
|
||||||
|
|
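Two things stand out in this pair of hunks: the host= to hosts= rename (the elasticsearch-py 8.x client takes a hosts argument, a URL string or list of URLs, so the old keyword never reached the client correctly), and the fact that the new logger.info calls print self.password and the full connection_info in plain text, which is worth scrubbing before production. A minimal connection sketch mirroring the fixed code; the URL and credentials are placeholders, not values from the commit:

    from elasticsearch import Elasticsearch

    connection_info = dict(hosts="http://127.0.0.1:9200")  # placeholder URL
    user, password = "elastic", "changeme"                 # hypothetical credentials
    if user != "" and password != "":
        connection_info.update(basic_auth=(user, password))

    es_client = Elasticsearch(**connection_info)
    print(es_client.info())  # smoke test: cluster name and version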
@@ -181,15 +186,17 @@ class ESKBService(KBService):
         search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k)
         hits = [hit for hit in search_results["hits"]["hits"]]
         docs_and_scores = [
-            (
+            # (
             Document(
                 page_content=hit["_source"]["context"],
                 metadata=hit["_source"]["metadata"],
-            ),
-            1.3,
-            )
+            )
+            # ,
+            # 1.3,
+            # )
             for hit in hits
         ]
+        # logger.info(f"docs_and_scores:{docs_and_scores}")
         return docs_and_scores

     def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
         results = []
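With the tuple wrapper and the hard-coded 1.3 placeholder score commented out, searchbyContentInternal now yields bare Document objects, matching the List[Document] signature introduced above. The tem_query body itself is outside this hunk; a plausible shape for a full-text match against the context field, given the index layout implied by hit["_source"]["context"] (hypothetical, not the commit's actual query):

    # Hypothetical keyword-query body; the commit's actual tem_query is not shown.
    tem_query = {
        "query": {
            "match": {
                "context": {
                    "query": "example query",  # the user's query string
                    "operator": "or",
                }
            }
        }
    }
    search_results = es_client.search(index="my_index", body=tem_query, size=2)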
@@ -227,7 +234,7 @@ class ESKBService(KBService):
             },
             "track_total_hits": True,
         }
-        print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}, base_file_name:{base_file_name}")
+        print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}")
         # note: set size explicitly, otherwise ES returns only 10 hits by default
         search_results = self.es_client_python.search(index=self.index_name, body=query,size=200)
         delete_list = [hit["_id"] for hit in search_results['hits']['hits']]
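The deletion search remains capped at size=200, so a file that was split into more than 200 chunks would leave orphaned documents in the index. A hedged alternative using elasticsearch-py's scan helper, which scrolls through every matching hit; query here is the same filepath filter built above in do_delete_doc:

    from elasticsearch.helpers import scan  # official paging helper

    delete_list = [
        hit["_id"]
        for hit in scan(self.es_client_python, index=self.index_name, query=query)
    ]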