From c5d1ff6621d36cf9932a0432c2801d78b7085ffa Mon Sep 17 00:00:00 2001
From: wvivi2023
Date: Mon, 15 Jan 2024 09:54:00 +0800
Subject: [PATCH] enhance

---
 configs/kb_config.py.example | 1 +
 server/agent/tools/search_knowledgebase_once.py | 2 +-
 server/chat/knowledge_base_chat.py | 3 ++-
 server/knowledge_base/kb_doc_api.py | 12 ++++++++----
 webui_pages/utils.py | 1 +
 5 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example
index 731148a..02ded35 100644
--- a/configs/kb_config.py.example
+++ b/configs/kb_config.py.example
@@ -19,6 +19,7 @@ CHUNK_SIZE = 250
 OVERLAP_SIZE = 50
 
 # 知识库匹配向量数量
+FIRST_VECTOR_SEARCH_TOP_K = 10
 VECTOR_SEARCH_TOP_K = 3
 
 # 知识库匹配的距离阈值,取值范围在0-1之间,SCORE越小,距离越小从而相关度越高,
diff --git a/server/agent/tools/search_knowledgebase_once.py b/server/agent/tools/search_knowledgebase_once.py
index c9a2d7b..c736559 100644
--- a/server/agent/tools/search_knowledgebase_once.py
+++ b/server/agent/tools/search_knowledgebase_once.py
@@ -19,7 +19,7 @@
 import json
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
 from server.chat.knowledge_base_chat import knowledge_base_chat
-from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
+from configs import FIRST_VECTOR_SEARCH_TOP_K, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
 import asyncio
 from server.agent import model_container
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 66afef2..a8ff596 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -3,6 +3,7 @@ from fastapi.responses import StreamingResponse
 from sse_starlette.sse import EventSourceResponse
 from fastapi.concurrency import run_in_threadpool
 from configs import (LLM_MODELS,
+                     FIRST_VECTOR_SEARCH_TOP_K,
                      VECTOR_SEARCH_TOP_K,
                      SCORE_THRESHOLD,
                      TEMPERATURE,
@@ -79,7 +80,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             max_tokens=max_tokens,
             callbacks=[callback],
         )
-        docs = search_docs(query, knowledge_base_name, 10, score_threshold)
+        docs = search_docs(query, knowledge_base_name, FIRST_VECTOR_SEARCH_TOP_K, score_threshold)
         # docs = await run_in_threadpool(search_docs,
         #                               query=query,
         #                               knowledge_base_name=knowledge_base_name,
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 96e984a..0c82dc6 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -52,9 +52,6 @@ def search_docs(
         doc_contents = [" ".join(jieba.cut(doc)) for doc in doc_contents]
         queryList = [" ".join(jieba.cut(doc)) for doc in queryList]
 
-        #print(f"****** search_docs, doc_contents:{doc_contents}")
-        #print(f"****** search_docs, queryList:{queryList}")
-
         vectorizer = TfidfVectorizer()
         tfidf_matrix = vectorizer.fit_transform(doc_contents)
         print(f"****** search_docs, tfidf_matrix:{tfidf_matrix}")
@@ -67,7 +64,14 @@
         docs_with_scores = [(doc, score) for doc, score in zip(docs, cosine_similarities)]
         sorted_docs = sorted(docs_with_scores, key=lambda x: x[1], reverse=True)
         print(f"****** search_docs, sorted_docs:{sorted_docs}")
-        data = [DocumentWithVSId(page_content = doc[0][0].page_content,id=doc[0][0].metadata.get("id"), score=doc[0][1],metadata=doc[0][0].metadata) for doc in docs_with_scores]
+        i = 0
+        for doc in sorted_docs:
+            if i>=VECTOR_SEARCH_TOP_K:
+                break
+            else:
+                data.append(DocumentWithVSId(page_content = doc[0][0].page_content,id=doc[0][0].metadata.get("id"), score=doc[0][1],metadata=doc[0][0].metadata))
+                i = i+1
+        print(f"****** search_docs top K , sorted_docs:{data}")
     else:
         data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
 
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 2f642d8..b7c5967 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -14,6 +14,7 @@ from configs import (
     CHUNK_SIZE,
     OVERLAP_SIZE,
     ZH_TITLE_ENHANCE,
+    FIRST_VECTOR_SEARCH_TOP_K,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
     HTTPX_DEFAULT_TIMEOUT,