diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 72dd251..747b685 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -13,6 +13,7 @@ import httpx import asyncio from server.chat.openai_chat import OpenAiChatMsgIn from fastapi.responses import StreamingResponse +import contextlib def set_httpx_timeout(timeout=60.0): @@ -173,7 +174,7 @@ class ApiRequest: except: retry -= 1 - def _stream2generator(self, response: StreamingResponse): + def _fastapi_stream2generator(self, response: StreamingResponse): ''' 将api.py中视图函数返回的StreamingResponse转化为同步生成器 ''' @@ -183,6 +184,14 @@ class ApiRequest: loop = asyncio.new_event_loop() return iter_over_async(response.body_iterator, loop) + def _httpx_stream2generator(self,response: contextlib._GeneratorContextManager): + ''' + 将httpx.stream返回的GeneratorContextManager转化为普通生成器 + ''' + with response as r: + for chunk in r.iter_text(None): + yield chunk + # 对话相关操作 def chat_fastchat( @@ -212,7 +221,7 @@ class ApiRequest: if no_remote_api: from server.chat.openai_chat import openai_chat response = openai_chat(msg) - return self._stream2generator(response) + return self._fastapi_stream2generator(response) else: data = msg.dict(exclude_unset=True, exclude_none=True) response = self.post( @@ -220,7 +229,7 @@ class ApiRequest: json=data, stream=stream, ) - return response + return self._httpx_stream2generator(response) def chat_chat( self, @@ -236,10 +245,10 @@ class ApiRequest: if no_remote_api: from server.chat.chat import chat response = chat(query) - return self._stream2generator(response) + return self._fastapi_stream2generator(response) else: response = self.post("/chat/chat", json=f"{query}", stream=True) - return response + return self._httpx_stream2generator(response) def knowledge_base_chat( self, @@ -256,14 +265,14 @@ class ApiRequest: if no_remote_api: from server.chat.knowledge_base_chat import knowledge_base_chat response = knowledge_base_chat(query, knowledge_base_name) - return self._stream2generator(response) + return self._fastapi_stream2generator(response) else: response = self.post( "/chat/knowledge_base_chat", json={"query": query, "knowledge_base_name": knowledge_base_name}, stream=True, ) - return response + return self._httpx_stream2generator(response) def duckduckgo_search_chat( self, @@ -279,14 +288,14 @@ class ApiRequest: if no_remote_api: from server.chat.duckduckgo_search_chat import duckduckgo_search_chat response = duckduckgo_search_chat(query) - return self._stream2generator(response) + return self._fastapi_stream2generator(response) else: response = self.post( "/chat/duckduckgo_search_chat", json=f"{query}", stream=True, ) - return response + return self._httpx_stream2generator(response) def bing_search_chat( self, @@ -302,14 +311,14 @@ class ApiRequest: if no_remote_api: from server.chat.bing_search_chat import bing_search_chat response = bing_search_chat(query) - return self._stream2generator(response) + return self._fastapi_stream2generator(response) else: response = self.post( "/chat/bing_search_chat", json=f"{query}", stream=True, ) - return response + return self._httpx_stream2generator(response) # 知识库相关操作