make api and webui support top_k paramter

This commit is contained in:
liunux4odoo 2023-08-03 15:47:53 +08:00
parent df61f31c8e
commit 5c804aac75
3 changed files with 6 additions and 4 deletions

View File

@ -52,6 +52,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
): ):
async def knowledge_base_chat_iterator(query: str, async def knowledge_base_chat_iterator(query: str,
knowledge_base_name: str, knowledge_base_name: str,
top_k: int,
) -> AsyncIterable[str]: ) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler() callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI( model = ChatOpenAI(
@ -62,7 +63,6 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL model_name=LLM_MODEL
) )
docs = lookup_vs(query, knowledge_base_name, top_k) docs = lookup_vs(query, knowledge_base_name, top_k)
context = "\n".join([doc.page_content for doc in docs]) context = "\n".join([doc.page_content for doc in docs])
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"]) prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@ -80,4 +80,4 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
yield token yield token
await task await task
return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name), media_type="text/event-stream") return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k), media_type="text/event-stream")

View File

@ -10,6 +10,7 @@ chat_box = ChatBox(
) )
def dialogue_page(api: ApiRequest): def dialogue_page(api: ApiRequest):
chat_box.init_session()
with st.sidebar: with st.sidebar:
def on_mode_change(): def on_mode_change():
mode = st.session_state.dialogue_mode mode = st.session_state.dialogue_mode

View File

@ -254,6 +254,7 @@ class ApiRequest:
self, self,
query: str, query: str,
knowledge_base_name: str, knowledge_base_name: str,
top_k: int = 3,
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
''' '''
@ -264,12 +265,12 @@ class ApiRequest:
if no_remote_api: if no_remote_api:
from server.chat.knowledge_base_chat import knowledge_base_chat from server.chat.knowledge_base_chat import knowledge_base_chat
response = knowledge_base_chat(query, knowledge_base_name) response = knowledge_base_chat(query, knowledge_base_name, top_k)
return self._fastapi_stream2generator(response) return self._fastapi_stream2generator(response)
else: else:
response = self.post( response = self.post(
"/chat/knowledge_base_chat", "/chat/knowledge_base_chat",
json={"query": query, "knowledge_base_name": knowledge_base_name}, json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
stream=True, stream=True,
) )
return self._httpx_stream2generator(response) return self._httpx_stream2generator(response)