diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 852e9ca..d32c497 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -52,6 +52,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
 ):
     async def knowledge_base_chat_iterator(query: str,
                                            knowledge_base_name: str,
+                                           top_k: int,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
@@ -62,7 +63,6 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
             openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
             model_name=LLM_MODEL
         )
-
         docs = lookup_vs(query, knowledge_base_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@@ -80,4 +80,4 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
             yield token
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name), media_type="text/event-stream")
+    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k), media_type="text/event-stream")
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index ea0f3b6..097b50f 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -10,6 +10,7 @@ chat_box = ChatBox(
 )
 
 def dialogue_page(api: ApiRequest):
+    chat_box.init_session()
     with st.sidebar:
         def on_mode_change():
             mode = st.session_state.dialogue_mode
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 747b685..2089179 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -254,6 +254,7 @@ class ApiRequest:
         self,
         query: str,
         knowledge_base_name: str,
+        top_k: int = 3,
         no_remote_api: bool = None,
     ):
         '''
@@ -264,12 +265,12 @@ class ApiRequest:
 
         if no_remote_api:
             from server.chat.knowledge_base_chat import knowledge_base_chat
-            response = knowledge_base_chat(query, knowledge_base_name)
+            response = knowledge_base_chat(query, knowledge_base_name, top_k)
             return self._fastapi_stream2generator(response)
         else:
             response = self.post(
                 "/chat/knowledge_base_chat",
-                json={"query": query, "knowledge_base_name": knowledge_base_name},
+                json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
                 stream=True,
             )
             return self._httpx_stream2generator(response)