make api and webui support top_k parameter

liunux4odoo 2023-08-03 15:47:53 +08:00
parent df61f31c8e
commit 5c804aac75
3 changed files with 6 additions and 4 deletions

View File

@@ -52,6 +52,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
):
async def knowledge_base_chat_iterator(query: str,
knowledge_base_name: str,
+top_k: int,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
@@ -62,7 +63,6 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
)
docs = lookup_vs(query, knowledge_base_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@@ -80,4 +80,4 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
yield token
await task
-return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name), media_type="text/event-stream")
+return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k), media_type="text/event-stream")
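Putting the server-side hunks together, the change is: accept top_k in the request body, add it to the inner iterator's signature, and forward it when the iterator is handed to the StreamingResponse. A minimal, self-contained sketch of that pattern with the retrieval and LLM pieces stubbed out (fake_lookup stands in for the repo's lookup_vs; the parameter descriptions and the default of 3 are assumptions, not taken from this commit):

from typing import AsyncIterable

from fastapi import Body, FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def fake_lookup(query: str, knowledge_base_name: str, top_k: int) -> list:
    # Placeholder for vector-store retrieval: return top_k dummy documents.
    return [f"doc {i} for {query!r} from {knowledge_base_name}" for i in range(top_k)]


@app.post("/chat/knowledge_base_chat")
def knowledge_base_chat(
    query: str = Body(..., description="user input"),
    knowledge_base_name: str = Body(..., description="knowledge base name"),
    top_k: int = Body(3, description="number of matched documents to retrieve"),
):
    async def knowledge_base_chat_iterator(
        query: str,
        knowledge_base_name: str,
        top_k: int,
    ) -> AsyncIterable[str]:
        docs = fake_lookup(query, knowledge_base_name, top_k)
        for doc in docs:
            yield doc + "\n"

    # top_k has to be forwarded here as well; this is the line the commit
    # fixes in the real endpoint.
    return StreamingResponse(
        knowledge_base_chat_iterator(query, knowledge_base_name, top_k),
        media_type="text/event-stream",
    )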

View File

@@ -10,6 +10,7 @@ chat_box = ChatBox(
)
def dialogue_page(api: ApiRequest):
chat_box.init_session()
with st.sidebar:
def on_mode_change():
mode = st.session_state.dialogue_mode
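The dialogue.py hunk above only shows surrounding context, so the widget that collects top_k is not visible here. A possible sketch, assuming a sidebar control inside dialogue_page whose value is passed straight to ApiRequest.knowledge_base_chat (the widget label, the default of 3, and the plain placeholder rendering are illustrative; the real page streams into its ChatBox):

import streamlit as st


def dialogue_page(api):
    with st.sidebar:
        # Let the user choose how many matched documents to retrieve.
        top_k = st.number_input("Matched documents (top_k)", min_value=1, max_value=20, value=3)
        knowledge_base_name = st.text_input("Knowledge base", value="samples")

    query = st.chat_input("Enter your question")
    if query:
        placeholder = st.empty()
        text = ""
        # The API client yields tokens; accumulate and re-render them as they arrive.
        for token in api.knowledge_base_chat(query, knowledge_base_name, top_k=top_k):
            text += token
            placeholder.markdown(text)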

View File

@@ -254,6 +254,7 @@ class ApiRequest:
self,
query: str,
knowledge_base_name: str,
+top_k: int = 3,
no_remote_api: bool = None,
):
'''
@@ -264,12 +265,12 @@
if no_remote_api:
from server.chat.knowledge_base_chat import knowledge_base_chat
-response = knowledge_base_chat(query, knowledge_base_name)
+response = knowledge_base_chat(query, knowledge_base_name, top_k)
return self._fastapi_stream2generator(response)
else:
response = self.post(
"/chat/knowledge_base_chat",
json={"query": query, "knowledge_base_name": knowledge_base_name},
json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
stream=True,
)
return self._httpx_stream2generator(response)
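For reference, the remote branch above amounts to a streaming POST whose JSON body now carries top_k alongside query and knowledge_base_name. A rough plain-httpx equivalent (the base URL/port is an assumption; the real client goes through ApiRequest.post and _httpx_stream2generator):

import httpx


def knowledge_base_chat(query: str, knowledge_base_name: str, top_k: int = 3,
                        base_url: str = "http://127.0.0.1:7861"):
    # Stream the response body and yield text chunks as they arrive.
    with httpx.stream(
        "POST",
        base_url + "/chat/knowledge_base_chat",
        json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
        timeout=None,
    ) as response:
        for chunk in response.iter_text():
            yield chunk


if __name__ == "__main__":
    for token in knowledge_base_chat("What does top_k control?", "samples", top_k=5):
        print(token, end="", flush=True)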