add top-k to search_chat apis and add top-k params in model_config

parent 5c804aac75
commit b62ea6bd2a
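The model_config side of the change is not shown on this page. Judging from the names imported in the hunks below, the added entries presumably look like this sketch (constant names are taken from the diff; the values are illustrative defaults, not confirmed by it):

# configs/model_config.py -- sketch only; values are illustrative
SEARCH_ENGINE_TOP_K = 3    # default number of hits for the *_search_chat APIs
VECTOR_SEARCH_TOP_K = 5    # default number of matched vectors in knowledge_base_chat
CACHED_VS_NUM = 1          # how many loaded vector stores @lru_cache keeps resident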
server/chat/bing_search_chat.py

@@ -2,7 +2,7 @@ from langchain.utilities import BingSearchAPIWrapper
 from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE)
+from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
@@ -12,7 +12,7 @@ import asyncio
 from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document
 
-def bing_search(text, result_len=3):
+def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
     if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
         return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
                  "title": "env info is not found",
@@ -33,9 +33,11 @@ def search_result2docs(search_results):
 
 
 def bing_search_chat(query: str = Body(..., description="用户输入", example="你好"),
+                     top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                      ):
     async def bing_search_chat_iterator(query: str,
-                                        ) -> AsyncIterable[str]:
+                                        top_k: int,
+                                        ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
             streaming=True,
@@ -46,7 +48,7 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example=
             model_name=LLM_MODEL
         )
 
-        results = bing_search(query, result_len=3)
+        results = bing_search(query, result_len=top_k)
         docs = search_result2docs(results)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@@ -64,4 +66,4 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example=
             yield token
         await task
 
-    return StreamingResponse(bing_search_chat_iterator(query), media_type="text/event-stream")
+    return StreamingResponse(bing_search_chat_iterator(query, top_k), media_type="text/event-stream")
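Since query and top_k are both declared as FastAPI Body parameters, they travel together in one JSON object, so clients can override the configured default per request. A hypothetical client sketch (the route path and port are assumptions, not part of this diff):

import requests

# Assumed mount point for the handler; adjust host and route to your deployment.
resp = requests.post(
    "http://127.0.0.1:7861/chat/bing_search_chat",
    json={"query": "你好", "top_k": 5},  # omit top_k to fall back to SEARCH_ENGINE_TOP_K
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="")  # tokens arrive as a server-sent-event stream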
server/chat/duckduckgo_search_chat.py

@@ -1,7 +1,7 @@
 from langchain.utilities import DuckDuckGoSearchAPIWrapper
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE)
+from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
@@ -12,7 +12,7 @@ from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document
 
 
-def duckduckgo_search(text, result_len=3):
+def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
     search = DuckDuckGoSearchAPIWrapper()
     return search.results(text, result_len)
 
@@ -28,8 +28,10 @@ def search_result2docs(search_results):
 
 
 def duckduckgo_search_chat(query: str = Body(..., description="用户输入", example="你好"),
+                           top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                            ):
     async def duckduckgo_search_chat_iterator(query: str,
+                                              top_k: int
                                               ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
@@ -41,7 +43,7 @@ def duckduckgo_search_chat(query: str = Body(..., description="用户输入", ex
             model_name=LLM_MODEL
         )
 
-        results = duckduckgo_search(query, result_len=3)
+        results = duckduckgo_search(query, result_len=top_k)
         docs = search_result2docs(results)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@@ -59,4 +61,4 @@ def duckduckgo_search_chat(query: str = Body(..., description="用户输入", ex
             yield token
         await task
 
-    return StreamingResponse(duckduckgo_search_chat_iterator(query), media_type="text/event-stream")
+    return StreamingResponse(duckduckgo_search_chat_iterator(query, top_k), media_type="text/event-stream")
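The DuckDuckGo changes mirror the Bing ones; the wrapper's results() call already takes the hit count as its second argument, so only the default changes. A quick standalone check of that call (assumes the duckduckgo-search package the wrapper depends on is installed):

from langchain.utilities import DuckDuckGoSearchAPIWrapper

search = DuckDuckGoSearchAPIWrapper()
results = search.results("langchain", 5)  # same call the patched helper makes
for r in results:
    print(r["title"], r["link"])  # hits also carry a "snippet" field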
server/chat/knowledge_base_chat.py

@@ -1,6 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
+                                  CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
                                   embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
@@ -18,15 +19,15 @@ from functools import lru_cache
 @lru_cache(1)
 def load_embeddings(model: str, device: str):
     embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
-                                        model_kwargs={'device': device})
+                                       model_kwargs={'device': device})
     return embeddings
 
 
-@lru_cache(1)
+@lru_cache(CACHED_VS_NUM)
 def load_vector_store(
-        knowledge_base_name: str,
-        embedding_model: str,
-        embedding_device: str,
+    knowledge_base_name: str,
+    embedding_model: str,
+    embedding_device: str,
 ):
     embeddings = load_embeddings(embedding_model, embedding_device)
     vs_path = get_vs_path(knowledge_base_name)
@@ -35,11 +36,11 @@ def load_vector_store(
 
 
 def lookup_vs(
-        query: str,
-        knowledge_base_name: str,
-        top_k: int = 3,
-        embedding_model: str = EMBEDDING_MODEL,
-        embedding_device: str = EMBEDDING_DEVICE,
+    query: str,
+    knowledge_base_name: str,
+    top_k: int = VECTOR_SEARCH_TOP_K,
+    embedding_model: str = EMBEDDING_MODEL,
+    embedding_device: str = EMBEDDING_DEVICE,
 ):
     search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device)
     docs = search_index.similarity_search(query, k=top_k)
@@ -48,7 +49,7 @@ def lookup_vs(
 
 def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
-                        top_k: int = Body(3, description="匹配向量数"),
+                        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         ):
     async def knowledge_base_chat_iterator(query: str,
                                            knowledge_base_name: str,
@@ -80,4 +81,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         yield token
     await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k), media_type="text/event-stream")
+    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k),
+                             media_type="text/event-stream")
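Raising @lru_cache(1) to @lru_cache(CACHED_VS_NUM) on load_vector_store lets up to CACHED_VS_NUM distinct (knowledge_base_name, embedding_model, embedding_device) combinations stay memoized at once, instead of reloading on every switch between knowledge bases. A minimal sketch of that caching behavior:

from functools import lru_cache

@lru_cache(maxsize=2)  # stand-in for CACHED_VS_NUM
def load(kb_name: str):
    print(f"loading {kb_name} from disk")  # runs only on a cache miss
    return object()  # stands in for a loaded vector store

load("samples")
load("samples")  # cache hit: not reloaded
load("docs")
load("wiki")     # maxsize exceeded: least recently used entry ("samples") is evicted
load("samples")  # miss again: reloaded from disk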