This commit is contained in:
wvivi2023 2024-01-15 09:54:00 +08:00
parent 314f22886e
commit c5d1ff6621
5 changed files with 13 additions and 6 deletions

View File

@ -19,6 +19,7 @@ CHUNK_SIZE = 250
OVERLAP_SIZE = 50 OVERLAP_SIZE = 50
# 知识库匹配向量数量 # 知识库匹配向量数量
FIRST_VECTOR_SEARCH_TOP_K = 10
VECTOR_SEARCH_TOP_K = 3 VECTOR_SEARCH_TOP_K = 3
# 知识库匹配的距离阈值取值范围在0-1之间SCORE越小距离越小从而相关度越高 # 知识库匹配的距离阈值取值范围在0-1之间SCORE越小距离越小从而相关度越高

View File

@ -19,7 +19,7 @@ import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from server.chat.knowledge_base_chat import knowledge_base_chat from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS from configs import FIRST_VECTOR_SEARCH_TOP_K, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import asyncio import asyncio
from server.agent import model_container from server.agent import model_container

View File

@ -3,6 +3,7 @@ from fastapi.responses import StreamingResponse
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS, from configs import (LLM_MODELS,
FIRST_VECTOR_SEARCH_TOP_K,
VECTOR_SEARCH_TOP_K, VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD, SCORE_THRESHOLD,
TEMPERATURE, TEMPERATURE,
@ -79,7 +80,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
max_tokens=max_tokens, max_tokens=max_tokens,
callbacks=[callback], callbacks=[callback],
) )
docs = search_docs(query, knowledge_base_name, 10, score_threshold) docs = search_docs(query, knowledge_base_name, FIRST_VECTOR_SEARCH_TOP_K, score_threshold)
# docs = await run_in_threadpool(search_docs, # docs = await run_in_threadpool(search_docs,
# query=query, # query=query,
# knowledge_base_name=knowledge_base_name, # knowledge_base_name=knowledge_base_name,

View File

@ -52,9 +52,6 @@ def search_docs(
doc_contents = [" ".join(jieba.cut(doc)) for doc in doc_contents] doc_contents = [" ".join(jieba.cut(doc)) for doc in doc_contents]
queryList = [" ".join(jieba.cut(doc)) for doc in queryList] queryList = [" ".join(jieba.cut(doc)) for doc in queryList]
#print(f"****** search_docs, doc_contents:{doc_contents}")
#print(f"****** search_docs, queryList:{queryList}")
vectorizer = TfidfVectorizer() vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(doc_contents) tfidf_matrix = vectorizer.fit_transform(doc_contents)
print(f"****** search_docs, tfidf_matrix:{tfidf_matrix}") print(f"****** search_docs, tfidf_matrix:{tfidf_matrix}")
@ -67,7 +64,14 @@ def search_docs(
docs_with_scores = [(doc, score) for doc, score in zip(docs, cosine_similarities)] docs_with_scores = [(doc, score) for doc, score in zip(docs, cosine_similarities)]
sorted_docs = sorted(docs_with_scores, key=lambda x: x[1], reverse=True) sorted_docs = sorted(docs_with_scores, key=lambda x: x[1], reverse=True)
print(f"****** search_docs, sorted_docs:{sorted_docs}") print(f"****** search_docs, sorted_docs:{sorted_docs}")
data = [DocumentWithVSId(page_content = doc[0][0].page_content,id=doc[0][0].metadata.get("id"), score=doc[0][1],metadata=doc[0][0].metadata) for doc in docs_with_scores] i = 0
for doc in sorted_docs:
if i>=VECTOR_SEARCH_TOP_K:
break
else:
data.append(DocumentWithVSId(page_content = doc[0][0].page_content,id=doc[0][0].metadata.get("id"), score=doc[0][1],metadata=doc[0][0].metadata))
i = i+1
print(f"****** search_docs top K , sorted_docs:{data}")
else: else:
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]

View File

@ -14,6 +14,7 @@ from configs import (
CHUNK_SIZE, CHUNK_SIZE,
OVERLAP_SIZE, OVERLAP_SIZE,
ZH_TITLE_ENHANCE, ZH_TITLE_ENHANCE,
FIRST_VECTOR_SEARCH_TOP_K,
VECTOR_SEARCH_TOP_K, VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K, SEARCH_ENGINE_TOP_K,
HTTPX_DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT,