diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 6746750..7eb7be2 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -20,13 +20,14 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), + description="历史对话", + examples=[[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}]] + ), + stream: bool = Body(False, description="流式输出"), ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: @@ -67,11 +68,19 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp for inum, doc in enumerate(docs) ] - async for token in callback.aiter(): - # Use server-sent-events to stream the response + if stream: + async for token in callback.aiter(): + # Use server-sent-events to stream the response + yield json.dumps({"answer": token, + "docs": source_documents}, ensure_ascii=False) + else: + answer = "" + async for token in callback.aiter(): + answer += token - yield json.dumps({"answer": token, + yield json.dumps({"answer": answer, - "docs": source_documents}, - ensure_ascii=False) + "docs": source_documents}, + ensure_ascii=False) + await task return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history), diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 95b2bd0..04cf933 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -61,13 +61,14 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), top_k: int =
Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), + description="历史对话", + examples=[[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}]] + ), + stream: bool = Body(False, description="流式输出"), ): if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") @@ -106,8 +107,15 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl for inum, doc in enumerate(docs) ] - async for token in callback.aiter(): - # Use server-sent-events to stream the response + if stream: + async for token in callback.aiter(): + # Use server-sent-events to stream the response + yield json.dumps({"answer": token, + "docs": source_documents}, ensure_ascii=False) + else: + answer = "" + async for token in callback.aiter(): + answer += token - yield json.dumps({"answer": token, + yield json.dumps({"answer": answer, "docs": source_documents}, ensure_ascii=False) diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 89d1b78..420414f 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -253,6 +253,7 @@ class ApiRequest: knowledge_base_name: str, top_k: int = VECTOR_SEARCH_TOP_K, history: List[Dict] = [], + stream: bool = True, no_remote_api: bool = None, ): ''' @@ -261,14 +262,22 @@ class ApiRequest: if no_remote_api is None: no_remote_api = self.no_remote_api + data = { + "query": query, + "knowledge_base_name": knowledge_base_name, + "top_k": top_k, + "history": history, + "stream": stream, + } + if no_remote_api: from server.chat.knowledge_base_chat import knowledge_base_chat - response = knowledge_base_chat(query, knowledge_base_name, top_k, history) + response = knowledge_base_chat(**data) return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( "/chat/knowledge_base_chat", -
json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "history": history}, + json=data, stream=True, ) return self._httpx_stream2generator(response, as_json=True) @@ -278,6 +287,7 @@ class ApiRequest: query: str, search_engine_name: str, top_k: int = SEARCH_ENGINE_TOP_K, + stream: bool = True, no_remote_api: bool = None, ): ''' @@ -286,14 +296,21 @@ class ApiRequest: if no_remote_api is None: no_remote_api = self.no_remote_api + data = { + "query": query, + "search_engine_name": search_engine_name, + "top_k": top_k, + "stream": stream, + } + if no_remote_api: from server.chat.search_engine_chat import search_engine_chat - response = search_engine_chat(query, search_engine_name, top_k) + response = search_engine_chat(**data) return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( "/chat/search_engine_chat", - json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k}, + json=data, stream=True, ) return self._httpx_stream2generator(response, as_json=True) @@ -331,14 +348,20 @@ class ApiRequest: if no_remote_api is None: no_remote_api = self.no_remote_api + data = { + "knowledge_base_name": knowledge_base_name, + "vector_store_type": vector_store_type, + "embed_model": embed_model, + } + if no_remote_api: from server.knowledge_base.kb_api import create_kb - response = run_async(create_kb(knowledge_base_name, vector_store_type, embed_model)) + response = run_async(create_kb(**data)) return response.dict() else: response = self.post( "/knowledge_base/create_knowledge_base", - json={"knowledge_base_name": knowledge_base_name, "vector_store_type": vector_store_type, "embed_model": embed_model}, + json=data, ) return response.json() @@ -443,14 +466,20 @@ class ApiRequest: if no_remote_api is None: no_remote_api = self.no_remote_api + data = { + "knowledge_base_name": knowledge_base_name, + "doc_name": doc_name, + "delete_content": delete_content, + } + if no_remote_api: from 
server.knowledge_base.kb_doc_api import delete_doc - response = run_async(delete_doc(knowledge_base_name, doc_name, delete_content)) + response = run_async(delete_doc(**data)) return response.dict() else: response = self.delete( "/knowledge_base/delete_doc", - json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name, "delete_content": delete_content}, + json=data, ) return response.json()