解决es的问题
This commit is contained in:
parent
39af20ed13
commit
194437a271
|
|
@ -91,8 +91,8 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
|
|||
metadata={})
|
||||
|
||||
source_documents = format_reference(kb_name, docs, api_address(is_public=True))
|
||||
logger.info(
|
||||
f"***********************************knowledge_base_chat_iterator:,after format_reference:{docs}")
|
||||
# logger.info(
|
||||
# f"***********************************knowledge_base_chat_iterator:,after format_reference:{docs}")
|
||||
end_time1 = time.time()
|
||||
execution_time1 = end_time1 - start_time1
|
||||
logger.info(f"kb_chat Execution time检索完成: {execution_time1:.6f} seconds")
|
||||
|
|
|
|||
|
|
@ -72,11 +72,15 @@ def search_docs(
|
|||
if kb is not None:
|
||||
if query:
|
||||
docs = kb.search_docs(query, top_k, score_threshold)
|
||||
# logger.info(f"search_docs, query:{query},top_k:{top_k},len(docs):{len(docs)},docs:{docs}")
|
||||
if docs is not None:
|
||||
logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold},len(docs):{len(docs)}")
|
||||
|
||||
docs_key = kb.search_content_internal(query,2)
|
||||
# logger.info(f"before merge_and_deduplicate docs_key:{docs_key}")
|
||||
if docs_key is not None:
|
||||
logger.info(f"before merge_and_deduplicate ,len(docs_key):{len(docs_key)}")
|
||||
docs = merge_and_deduplicate(docs, docs_key)
|
||||
logger.info(f"after merge_and_deduplicate docs:{docs}")
|
||||
if docs is not None:
|
||||
logger.info(f"after merge_and_deduplicate len(docs):{len(docs)}")
|
||||
data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
|
||||
elif file_name or metadata:
|
||||
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
||||
|
|
@ -87,12 +91,15 @@ def search_docs(
|
|||
|
||||
def merge_and_deduplicate(list1: List[Document], list2: List[Document]) -> List[Document]:
|
||||
# 使用字典存储唯一的 Document
|
||||
merged_dict = {doc.page_content: doc for doc in list1}
|
||||
merged_dict = {}
|
||||
if list1 is not None:
|
||||
merged_dict = {doc.page_content: doc for doc in list1}
|
||||
|
||||
# 遍历 list2,将新的 Document 添加到字典
|
||||
for doc in list2:
|
||||
if doc.page_content not in merged_dict:
|
||||
merged_dict[doc.page_content] = doc
|
||||
if list2 is not None:
|
||||
for doc in list2:
|
||||
if doc.page_content not in merged_dict:
|
||||
merged_dict[doc.page_content] = doc
|
||||
|
||||
# 返回去重后的列表
|
||||
return list(merged_dict.values())
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class ESKBService(KBService):
|
|||
# ES python客户端连接(仅连接)
|
||||
logger.info(f"connection_info:{connection_info}")
|
||||
self.es_client_python = Elasticsearch(**connection_info)
|
||||
# logger.info(f"after Elasticsearch connection_info:{connection_info}")
|
||||
logger.info(f"after Elasticsearch connection_info:{connection_info}")
|
||||
except ConnectionError:
|
||||
logger.error("连接到 Elasticsearch 失败!")
|
||||
raise ConnectionError
|
||||
|
|
@ -89,9 +89,10 @@ class ESKBService(KBService):
|
|||
es_url=f"{self.scheme}://{self.IP}:{self.PORT}",
|
||||
index_name=self.index_name,
|
||||
query_field="context",
|
||||
distance_strategy="COSINE",
|
||||
vector_query_field="dense_vector",
|
||||
embedding=self.embeddings_model,
|
||||
strategy=ApproxRetrievalStrategy(),
|
||||
# strategy=ApproxRetrievalStrategy(),
|
||||
es_params={
|
||||
"timeout": 60,
|
||||
},
|
||||
|
|
@ -106,6 +107,7 @@ class ESKBService(KBService):
|
|||
params["es_params"].update(client_key=self.client_key)
|
||||
params["es_params"].update(client_cert=self.client_cert)
|
||||
self.db = ElasticsearchStore(**params)
|
||||
logger.info(f"after ElasticsearchStore create params:{params}")
|
||||
except ConnectionError:
|
||||
logger.error("### 初始化 Elasticsearch 失败!")
|
||||
raise ConnectionError
|
||||
|
|
@ -138,14 +140,20 @@ class ESKBService(KBService):
|
|||
def vs_type(self) -> str:
|
||||
return SupportedVSType.ES
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float)->List[Document]:
|
||||
# 确保 ElasticsearchStore 正确初始化
|
||||
if not hasattr(self, "db") or self.db is None:
|
||||
raise ValueError("ElasticsearchStore (db) not initialized.")
|
||||
|
||||
# 文本相似性检索
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.db,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
|
||||
return docs
|
||||
|
||||
def searchbyContent(self, query:str, top_k: int = 2):
|
||||
|
|
|
|||
|
|
@ -78,11 +78,11 @@ class FaissKBService(KBService):
|
|||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def searchbyContent(self):
|
||||
def searchbyContent(self, query:str, top_k: int = 2):
|
||||
pass
|
||||
|
||||
def searchbyContentInternal(self):
|
||||
pass
|
||||
def searchbyContentInternal(self, query:str, top_k: int = 2):
|
||||
return None
|
||||
|
||||
def do_add_doc(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -88,10 +88,10 @@ class MilvusKBService(KBService):
|
|||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def searchbyContent(self):
|
||||
def searchbyContent(self, query:str, top_k: int = 2):
|
||||
pass
|
||||
|
||||
def searchbyContentInternal(self):
|
||||
def searchbyContentInternal(self, query:str, top_k: int = 2):
|
||||
pass
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
|
|
|
|||
|
|
@ -84,10 +84,10 @@ class PGKBService(KBService):
|
|||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def searchbyContent(self):
|
||||
def searchbyContent(self, query:str, top_k: int = 2):
|
||||
pass
|
||||
|
||||
def searchbyContentInternal(self):
|
||||
def searchbyContentInternal(self, query:str, top_k: int = 2):
|
||||
pass
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
|
|
|
|||
|
|
@ -95,10 +95,10 @@ class RelytKBService(KBService):
|
|||
return score_threshold_process(score_threshold, top_k, docs)
|
||||
|
||||
|
||||
def searchbyContent(self):
|
||||
def searchbyContent(self, query:str, top_k: int = 2):
|
||||
pass
|
||||
|
||||
def searchbyContentInternal(self):
|
||||
def searchbyContentInternal(self, query:str, top_k: int = 2):
|
||||
pass
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
|
|
|
|||
|
|
@ -79,10 +79,10 @@ class ZillizKBService(KBService):
|
|||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def searchbyContent(self):
|
||||
def searchbyContent(self, query:str, top_k: int = 2):
|
||||
pass
|
||||
|
||||
def searchbyContentInternal(self):
|
||||
def searchbyContentInternal(self, query:str, top_k: int = 2):
|
||||
pass
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ if __name__ == "__main__":
|
|||
|
||||
with st.sidebar:
|
||||
st.image(
|
||||
get_img_base64("logo-long-chatchat-trans-v2.png"), use_container_width=True
|
||||
get_img_base64("logo-long-chatchat-trans-v2.png"), use_column_width=True
|
||||
)
|
||||
st.caption(
|
||||
f"""<p align="right">当前版本:{__version__}</p>""",
|
||||
|
|
|
|||
Loading…
Reference in New Issue