Fix the Elasticsearch (ES) issues

This commit is contained in:
parent 39af20ed13
commit 194437a271
@@ -91,8 +91,8 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
                 metadata={})

         source_documents = format_reference(kb_name, docs, api_address(is_public=True))
-        logger.info(
-            f"***********************************knowledge_base_chat_iterator:,after format_reference:{docs}")
+        # logger.info(
+        #     f"***********************************knowledge_base_chat_iterator:,after format_reference:{docs}")
         end_time1 = time.time()
         execution_time1 = end_time1 - start_time1
         logger.info(f"kb_chat Execution time检索完成: {execution_time1:.6f} seconds")

@@ -72,11 +72,15 @@ def search_docs(
     if kb is not None:
         if query:
             docs = kb.search_docs(query, top_k, score_threshold)
-            # logger.info(f"search_docs, query:{query},top_k:{top_k},len(docs):{len(docs)},docs:{docs}")
+            if docs is not None:
+                logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold},len(docs):{len(docs)}")
+
             docs_key = kb.search_content_internal(query, 2)
-            # logger.info(f"before merge_and_deduplicate docs_key:{docs_key}")
+            if docs_key is not None:
+                logger.info(f"before merge_and_deduplicate ,len(docs_key):{len(docs_key)}")
             docs = merge_and_deduplicate(docs, docs_key)
-            logger.info(f"after merge_and_deduplicate docs:{docs}")
+            if docs is not None:
+                logger.info(f"after merge_and_deduplicate len(docs):{len(docs)}")
             data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
         elif file_name or metadata:
             data = kb.list_docs(file_name=file_name, metadata=metadata)

@@ -87,9 +91,12 @@ def search_docs(
 def merge_and_deduplicate(list1: List[Document], list2: List[Document]) -> List[Document]:
     # Store unique Documents in a dict keyed by page_content
-    merged_dict = {doc.page_content: doc for doc in list1}
+    merged_dict = {}
+    if list1 is not None:
+        merged_dict = {doc.page_content: doc for doc in list1}

     # Iterate over list2, adding Documents the dict does not have yet
-    for doc in list2:
-        if doc.page_content not in merged_dict:
-            merged_dict[doc.page_content] = doc
+    if list2 is not None:
+        for doc in list2:
+            if doc.page_content not in merged_dict:
+                merged_dict[doc.page_content] = doc

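Reviewer note: a minimal usage sketch of the reworked merge_and_deduplicate, assuming the function is in scope and ends with `return list(merged_dict.values())` (the return statement falls outside this hunk). Documents are deduplicated by page_content, the copy from list1 winning on collision, and the new None guards make degenerate inputs safe:

```python
from langchain_core.documents import Document  # assumed import path

a = [Document(page_content="alpha"), Document(page_content="beta")]
b = [Document(page_content="beta"), Document(page_content="gamma")]

merged = merge_and_deduplicate(a, b)
print([d.page_content for d in merged])   # ['alpha', 'beta', 'gamma']

# Previously these raised TypeError; with the guards they degrade gracefully.
print(merge_and_deduplicate(None, b))     # just the two documents from b
print(merge_and_deduplicate(None, None))  # []
```
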
@@ -58,7 +58,7 @@ class ESKBService(KBService):
             # ES Python client connection (connect only)
             logger.info(f"connection_info:{connection_info}")
             self.es_client_python = Elasticsearch(**connection_info)
-            # logger.info(f"after Elasticsearch connection_info:{connection_info}")
+            logger.info(f"after Elasticsearch connection_info:{connection_info}")
         except ConnectionError:
             logger.error("连接到 Elasticsearch 失败!")
             raise ConnectionError

@@ -89,9 +89,10 @@ class ESKBService(KBService):
                 es_url=f"{self.scheme}://{self.IP}:{self.PORT}",
                 index_name=self.index_name,
                 query_field="context",
+                distance_strategy="COSINE",
                 vector_query_field="dense_vector",
                 embedding=self.embeddings_model,
-                strategy=ApproxRetrievalStrategy(),
+                # strategy=ApproxRetrievalStrategy(),
                 es_params={
                     "timeout": 60,
                 },

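Reviewer note: the two changes in this hunk work together: dropping the explicit `strategy=ApproxRetrievalStrategy()` falls back to the store's default approximate-kNN strategy, and `distance_strategy="COSINE"` pins its similarity metric. A sketch of the resulting constructor call, with placeholder connection values rather than the service's real config:

```python
from langchain_community.embeddings import FakeEmbeddings  # placeholder embeddings
from langchain_elasticsearch import ElasticsearchStore

db = ElasticsearchStore(
    es_url="http://localhost:9200",   # real value comes from self.scheme/self.IP/self.PORT
    index_name="example_index",       # real value comes from self.index_name
    query_field="context",
    vector_query_field="dense_vector",
    distance_strategy="COSINE",       # string literal, as in the hunk
    embedding=FakeEmbeddings(size=768),  # stands in for self.embeddings_model
    es_params={"timeout": 60},
)
```
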
@@ -106,6 +107,7 @@ class ESKBService(KBService):
                 params["es_params"].update(client_key=self.client_key)
                 params["es_params"].update(client_cert=self.client_cert)
             self.db = ElasticsearchStore(**params)
+            logger.info(f"after ElasticsearchStore create params:{params}")
         except ConnectionError:
             logger.error("### 初始化 Elasticsearch 失败!")
             raise ConnectionError

@@ -138,14 +140,20 @@ class ESKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.ES

-    def do_search(self, query: str, top_k: int, score_threshold: float):
+    def do_search(self, query: str, top_k: int, score_threshold: float) -> List[Document]:
+        # Make sure the ElasticsearchStore was initialized correctly
+        if not hasattr(self, "db") or self.db is None:
+            raise ValueError("ElasticsearchStore (db) not initialized.")
+
         # Text similarity retrieval
         retriever = get_Retriever("vectorstore").from_vectorstore(
             self.db,
             top_k=top_k,
             score_threshold=score_threshold,
         )
+
         docs = retriever.get_relevant_documents(query)
+
         return docs

     def searchbyContent(self, query: str, top_k: int = 2):

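Reviewer note: the new guard in do_search fails fast with a descriptive ValueError instead of an obscure AttributeError from deep inside retrieval when initialization did not complete. The pattern in isolation, using a hypothetical stand-in class:

```python
class Example:
    """Hypothetical stand-in for ESKBService, to show the guard alone."""

    def do_search(self):
        # Covers both failure modes: `db` was never assigned (hasattr is False)
        # and `db` was assigned but is still None.
        if not hasattr(self, "db") or self.db is None:
            raise ValueError("ElasticsearchStore (db) not initialized.")
        return self.db

try:
    Example().do_search()
except ValueError as e:
    print(e)  # ElasticsearchStore (db) not initialized.
```
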
@@ -78,11 +78,11 @@ class FaissKBService(KBService):
         docs = retriever.get_relevant_documents(query)
         return docs

-    def searchbyContent(self):
+    def searchbyContent(self, query: str, top_k: int = 2):
         pass

-    def searchbyContentInternal(self):
-        pass
+    def searchbyContentInternal(self, query: str, top_k: int = 2):
+        return None

     def do_add_doc(
         self,

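Reviewer note: `return None` is behaviorally identical to `pass` here (a Python function without an explicit return already yields None), but it spells out the contract: backends without a keyword-search implementation hand back None, and the callers in search_docs tolerate that. A sketch of the interplay, assuming search_content_internal dispatches to searchbyContentInternal:

```python
docs = kb.search_docs(query, top_k, score_threshold)  # vector hits
docs_key = kb.search_content_internal(query, 2)       # None on FAISS/Milvus/PG/Relyt/Zilliz
if docs_key is not None:
    logger.info(f"before merge_and_deduplicate ,len(docs_key):{len(docs_key)}")
docs = merge_and_deduplicate(docs, docs_key)          # None guard keeps docs unchanged
```
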
@@ -88,10 +88,10 @@ class MilvusKBService(KBService):
         docs = retriever.get_relevant_documents(query)
         return docs

-    def searchbyContent(self):
+    def searchbyContent(self, query: str, top_k: int = 2):
         pass

-    def searchbyContentInternal(self):
+    def searchbyContentInternal(self, query: str, top_k: int = 2):
         pass

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:

@@ -84,10 +84,10 @@ class PGKBService(KBService):
         docs = retriever.get_relevant_documents(query)
         return docs

-    def searchbyContent(self):
+    def searchbyContent(self, query: str, top_k: int = 2):
         pass

-    def searchbyContentInternal(self):
+    def searchbyContentInternal(self, query: str, top_k: int = 2):
         pass

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:

@@ -95,10 +95,10 @@ class RelytKBService(KBService):
         return score_threshold_process(score_threshold, top_k, docs)


-    def searchbyContent(self):
+    def searchbyContent(self, query: str, top_k: int = 2):
         pass

-    def searchbyContentInternal(self):
+    def searchbyContentInternal(self, query: str, top_k: int = 2):
         pass

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:

@@ -79,10 +79,10 @@ class ZillizKBService(KBService):
         docs = retriever.get_relevant_documents(query)
         return docs

-    def searchbyContent(self):
+    def searchbyContent(self, query: str, top_k: int = 2):
         pass

-    def searchbyContentInternal(self):
+    def searchbyContentInternal(self, query: str, top_k: int = 2):
         pass

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:

@@ -46,7 +46,7 @@ if __name__ == "__main__":

     with st.sidebar:
         st.image(
-            get_img_base64("logo-long-chatchat-trans-v2.png"), use_container_width=True
+            get_img_base64("logo-long-chatchat-trans-v2.png"), use_column_width=True
         )
         st.caption(
             f"""<p align="right">当前版本:{__version__}</p>""",

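Reviewer note: the final hunk reverts st.image from use_container_width back to use_column_width. Both parameters control the same scaling behavior across Streamlit generations (use_column_width is the older spelling, later deprecated in favor of use_container_width), so the revert suggests the deployment pins an older Streamlit. Minimal equivalent call, with a placeholder path:

```python
import streamlit as st

# Older Streamlit releases only understand use_column_width.
st.image("logo-long-chatchat-trans-v2.png", use_column_width=True)
```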