add top_k to search_chat APIs and add top-k params in model_config

imClumsyPanda 2023-08-03 17:06:43 +08:00
parent 5c804aac75
commit b62ea6bd2a
3 changed files with 27 additions and 21 deletions
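
The model_config side of the change is not included in the hunks below, but the new imports (SEARCH_ENGINE_TOP_K, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM) imply entries along these lines. A minimal sketch only: the names come from the imports in the changed files, while the values are guesses based on the hardcoded defaults this commit replaces.

# configs/model_config.py (sketch, not part of this diff)

# default number of results requested from Bing/DuckDuckGo by the search_chat APIs
SEARCH_ENGINE_TOP_K = 3

# default number of matched vectors returned by knowledge-base search
VECTOR_SEARCH_TOP_K = 3

# how many loaded vector stores lru_cache may keep alive at once
CACHED_VS_NUM = 1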

View File

@@ -2,7 +2,7 @@ from langchain.utilities import BingSearchAPIWrapper
 from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE)
+from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
@@ -12,7 +12,7 @@ import asyncio
 from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document

-def bing_search(text, result_len=3):
+def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
     if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
         return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
                  "title": "env info is not found",
@@ -33,9 +33,11 @@ def search_result2docs(search_results):


 def bing_search_chat(query: str = Body(..., description="用户输入", example="你好"),
+                     top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                      ):
     async def bing_search_chat_iterator(query: str,
-                                        ) -> AsyncIterable[str]:
+                                        top_k: int,
+                                        ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
             streaming=True,
@@ -46,7 +48,7 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example=
             model_name=LLM_MODEL
         )

-        results = bing_search(query, result_len=3)
+        results = bing_search(query, result_len=top_k)
         docs = search_result2docs(results)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@@ -64,4 +66,4 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example=
             yield token
         await task

-    return StreamingResponse(bing_search_chat_iterator(query), media_type="text/event-stream")
+    return StreamingResponse(bing_search_chat_iterator(query, top_k), media_type="text/event-stream")
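
With top_k exposed on the endpoint, callers can tune the result count per request instead of being pinned to the hardcoded 3. A minimal client sketch; the route path and port are assumptions, not taken from this diff.

import requests

# Stream the answer; top_k overrides SEARCH_ENGINE_TOP_K for this request only.
resp = requests.post(
    "http://127.0.0.1:7861/chat/bing_search_chat",  # assumed mount point
    json={"query": "你好", "top_k": 5},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="")

Omitting "top_k" from the JSON body falls back to the SEARCH_ENGINE_TOP_K default declared in Body().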

View File

@@ -1,7 +1,7 @@
 from langchain.utilities import DuckDuckGoSearchAPIWrapper
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE)
+from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
@@ -12,7 +12,7 @@ from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document


-def duckduckgo_search(text, result_len=3):
+def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
     search = DuckDuckGoSearchAPIWrapper()
     return search.results(text, result_len)

@@ -28,8 +28,10 @@ def search_result2docs(search_results):


 def duckduckgo_search_chat(query: str = Body(..., description="用户输入", example="你好"),
+                           top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                            ):
     async def duckduckgo_search_chat_iterator(query: str,
+                                              top_k: int
                                               ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
@@ -41,7 +43,7 @@ def duckduckgo_search_chat(query: str = Body(..., description="用户输入", ex
             model_name=LLM_MODEL
        )

-        results = duckduckgo_search(query, result_len=3)
+        results = duckduckgo_search(query, result_len=top_k)
         docs = search_result2docs(results)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@@ -59,4 +61,4 @@ def duckduckgo_search_chat(query: str = Body(..., description="用户输入", ex
             yield token
         await task

-    return StreamingResponse(duckduckgo_search_chat_iterator(query), media_type="text/event-stream")
+    return StreamingResponse(duckduckgo_search_chat_iterator(query, top_k), media_type="text/event-stream")
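
Both search endpoints use the same pattern: the Body() default is read from the config once, when the route is defined. A self-contained illustration of that behavior; the app and route here are hypothetical, not part of the commit.

from fastapi import Body, FastAPI

SEARCH_ENGINE_TOP_K = 3  # stand-in for the configs.model_config value

app = FastAPI()

@app.post("/echo_top_k")
def echo_top_k(query: str = Body(...),
               top_k: int = Body(SEARCH_ENGINE_TOP_K)):
    # The default is captured when the route is defined, so editing the
    # config takes effect on the next server start, not per request.
    return {"query": query, "top_k": top_k}

A request body of {"query": "hi"} returns top_k=3, while {"query": "hi", "top_k": 10} overrides it for that call.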

View File

@@ -1,6 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
+                                  CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
                                   embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
@@ -18,15 +19,15 @@ from functools import lru_cache
 @lru_cache(1)
 def load_embeddings(model: str, device: str):
     embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
-                                      model_kwargs={'device': device})
+                                       model_kwargs={'device': device})
     return embeddings


-@lru_cache(1)
+@lru_cache(CACHED_VS_NUM)
 def load_vector_store(
-    knowledge_base_name: str,
-    embedding_model: str,
-    embedding_device: str,
+        knowledge_base_name: str,
+        embedding_model: str,
+        embedding_device: str,
 ):
     embeddings = load_embeddings(embedding_model, embedding_device)
     vs_path = get_vs_path(knowledge_base_name)
@@ -35,11 +36,11 @@ def load_vector_store(


 def lookup_vs(
-    query: str,
-    knowledge_base_name: str,
-    top_k: int = 3,
-    embedding_model: str = EMBEDDING_MODEL,
-    embedding_device: str = EMBEDDING_DEVICE,
+        query: str,
+        knowledge_base_name: str,
+        top_k: int = VECTOR_SEARCH_TOP_K,
+        embedding_model: str = EMBEDDING_MODEL,
+        embedding_device: str = EMBEDDING_DEVICE,
 ):
     search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device)
     docs = search_index.similarity_search(query, k=top_k)
@@ -48,7 +49,7 @@ def lookup_vs(

 def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
-                        top_k: int = Body(3, description="匹配向量数"),
+                        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         ):
     async def knowledge_base_chat_iterator(query: str,
                                            knowledge_base_name: str,
@@ -80,4 +81,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
             yield token
         await task

-    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k), media_type="text/event-stream")
+    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k),
+                             media_type="text/event-stream")
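
The switch from @lru_cache(1) to @lru_cache(CACHED_VS_NUM) matters once more than one knowledge base is in use: with maxsize=1, alternating queries against two knowledge bases evict and reload the vector store on every call. A toy demonstration of the caching behavior; load_vector_store here is a stub, not the real loader.

from functools import lru_cache

CACHED_VS_NUM = 2  # illustrative; the real value lives in configs.model_config

@lru_cache(CACHED_VS_NUM)
def load_vector_store(knowledge_base_name: str):
    print(f"loading {knowledge_base_name} from disk...")  # expensive in reality
    return object()  # stands in for a loaded vector store

load_vector_store("samples")  # miss: loads from disk
load_vector_store("wiki")     # miss: loads from disk
load_vector_store("samples")  # hit: still cached because maxsize >= 2
print(load_vector_store.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=2, currsize=2)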