diff --git a/server/chat/bing_search_chat.py b/server/chat/bing_search_chat.py
index bbb9089..9576d6d 100644
--- a/server/chat/bing_search_chat.py
+++ b/server/chat/bing_search_chat.py
@@ -2,7 +2,7 @@ from langchain.utilities import BingSearchAPIWrapper
 from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE)
+from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
@@ -12,7 +12,7 @@ import asyncio
 from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document
 
-def bing_search(text, result_len=3):
+def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
     if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
         return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
                  "title": "env info is not found",
@@ -33,9 +33,11 @@ def search_result2docs(search_results):
 
 
 def bing_search_chat(query: str = Body(..., description="用户输入", example="你好"),
+                     top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                      ):
     async def bing_search_chat_iterator(query: str,
-                                        ) -> AsyncIterable[str]:
+                                        top_k: int,
+                                        ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
             streaming=True,
@@ -46,7 +48,7 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example=
             model_name=LLM_MODEL
         )
 
-        results = bing_search(query, result_len=3)
+        results = bing_search(query, result_len=top_k)
         docs = search_result2docs(results)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@@ -64,4 +66,4 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example=
             yield token
         await task
 
-    return StreamingResponse(bing_search_chat_iterator(query), media_type="text/event-stream")
+    return StreamingResponse(bing_search_chat_iterator(query, top_k), media_type="text/event-stream")
diff --git a/server/chat/duckduckgo_search_chat.py b/server/chat/duckduckgo_search_chat.py
index c10e065..b75ed25 100644
--- a/server/chat/duckduckgo_search_chat.py
+++ b/server/chat/duckduckgo_search_chat.py
@@ -1,7 +1,7 @@
 from langchain.utilities import DuckDuckGoSearchAPIWrapper
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE)
+from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
@@ -12,7 +12,7 @@ from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document
 
 
-def duckduckgo_search(text, result_len=3):
+def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
     search = DuckDuckGoSearchAPIWrapper()
     return search.results(text, result_len)
 
@@ -28,8 +28,10 @@
 
 
 def duckduckgo_search_chat(query: str = Body(..., description="用户输入", example="你好"),
+                           top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                            ):
     async def duckduckgo_search_chat_iterator(query: str,
+                                              top_k: int
                                               ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
@@ -41,7 +43,7 @@ def duckduckgo_search_chat(query: str = Body(..., description="用户输入", ex
             model_name=LLM_MODEL
         )
 
-        results = duckduckgo_search(query, result_len=3)
+        results = duckduckgo_search(query, result_len=top_k)
         docs = search_result2docs(results)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@@ -59,4 +61,4 @@ def duckduckgo_search_chat(query: str = Body(..., description="用户输入", ex
             yield token
         await task
 
-    return StreamingResponse(duckduckgo_search_chat_iterator(query), media_type="text/event-stream")
+    return StreamingResponse(duckduckgo_search_chat_iterator(query, top_k), media_type="text/event-stream")
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index d32c497..1351e82 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -1,6 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
+                                  CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
                                   embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
@@ -18,15 +19,15 @@ from functools import lru_cache
 @lru_cache(1)
 def load_embeddings(model: str, device: str):
     embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
-                                        model_kwargs={'device': device})
+                                       model_kwargs={'device': device})
     return embeddings
 
 
-@lru_cache(1)
+@lru_cache(CACHED_VS_NUM)
 def load_vector_store(
-    knowledge_base_name: str,
-    embedding_model: str,
-    embedding_device: str,
+        knowledge_base_name: str,
+        embedding_model: str,
+        embedding_device: str,
 ):
     embeddings = load_embeddings(embedding_model, embedding_device)
     vs_path = get_vs_path(knowledge_base_name)
@@ -35,11 +36,11 @@ def load_vector_store(
 
 
 def lookup_vs(
-    query: str,
-    knowledge_base_name: str,
-    top_k: int = 3,
-    embedding_model: str = EMBEDDING_MODEL,
-    embedding_device: str = EMBEDDING_DEVICE,
+        query: str,
+        knowledge_base_name: str,
+        top_k: int = VECTOR_SEARCH_TOP_K,
+        embedding_model: str = EMBEDDING_MODEL,
+        embedding_device: str = EMBEDDING_DEVICE,
 ):
     search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device)
     docs = search_index.similarity_search(query, k=top_k)
@@ -48,7 +49,7 @@ def lookup_vs(
 
 def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
-                        top_k: int = Body(3, description="匹配向量数"),
+                        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         ):
     async def knowledge_base_chat_iterator(query: str,
                                            knowledge_base_name: str,
@@ -80,4 +81,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
             yield token
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k), media_type="text/event-stream")
+    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k),
+                             media_type="text/event-stream")
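
For reference, a minimal client-side sketch of exercising the new top_k parameter. This is a hypothetical example, not part of the patch: the /chat route prefix and the localhost:7861 address are assumptions, since the router registration is not shown in this diff.

    import requests  # any HTTP client works; requests is used here for brevity

    # With multiple Body(...) parameters, FastAPI expects them as fields of one JSON object.
    resp = requests.post(
        "http://localhost:7861/chat/knowledge_base_chat",  # assumed mount point, not shown in this diff
        json={"query": "你好", "knowledge_base_name": "samples", "top_k": 5},
        stream=True,  # the endpoint returns a text/event-stream StreamingResponse
    )
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")

Omitting "top_k" from the JSON body falls back to the server-side defaults introduced above (VECTOR_SEARCH_TOP_K for knowledge_base_chat, SEARCH_ENGINE_TOP_K for the two search endpoints).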