diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py index 7b69265..a7ad807 100644 --- a/server/chat/openai_chat.py +++ b/server/chat/openai_chat.py @@ -1,7 +1,7 @@ from fastapi.responses import StreamingResponse from typing import List import openai -from configs.model_config import llm_model_dict, LLM_MODEL +from configs.model_config import llm_model_dict, LLM_MODEL, logger from pydantic import BaseModel @@ -33,19 +33,23 @@ async def openai_chat(msg: OpenAiChatMsgIn): data = msg.dict() data["streaming"] = True data.pop("stream") - response = openai.ChatCompletion.create(**data) - if msg.stream: - for chunk in response.choices[0].message.content: - print(chunk) - yield chunk - else: - answer = "" - for chunk in response.choices[0].message.content: - answer += chunk - print(answer) - yield(answer) - + try: + response = openai.ChatCompletion.create(**data) + if msg.stream: + for chunk in response.choices[0].message.content: + print(chunk) + yield chunk + else: + answer = "" + for chunk in response.choices[0].message.content: + answer += chunk + print(answer) + yield(answer) + except Exception as e: + print(type(e)) + logger.error(e) + return StreamingResponse( get_response(msg), media_type='text/event-stream', diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index a39a196..89e148f 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -98,6 +98,9 @@ def dialogue_page(api: ApiRequest): text = "" r = api.chat_chat(prompt, history) for t in r: + if error_msg := check_error_msg(t): # check whether error occured + st.error(error_msg) + break text += t chat_box.update_msg(text) chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 @@ -109,6 +112,8 @@ def dialogue_page(api: ApiRequest): ]) text = "" for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history): + if error_msg := check_error_msg(t): # check whether error occured + st.error(error_msg) text += d["answer"] chat_box.update_msg(text, 0) chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) @@ -120,6 +125,8 @@ def dialogue_page(api: ApiRequest): ]) text = "" for d in api.search_engine_chat(prompt, search_engine, se_top_k): + if error_msg := check_error_msg(t): # check whether error occured + st.error(error_msg) text += d["answer"] chat_box.update_msg(text, 0) chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index f963a67..47eb24c 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -119,8 +119,6 @@ def knowledge_base_page(api: ApiRequest): elif selected_kb: kb = selected_kb["kb_name"] - - # 上传文件 # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) files = st.file_uploader("上传知识文件", diff --git a/webui_pages/utils.py b/webui_pages/utils.py index d2772d5..c721e97 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -8,6 +8,7 @@ from configs.model_config import ( LLM_MODEL, VECTOR_SEARCH_TOP_K, SEARCH_ENGINE_TOP_K, + logger, ) import httpx import asyncio @@ -24,6 +25,7 @@ from configs.model_config import NLTK_DATA_PATH import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path + def set_httpx_timeout(timeout=60.0): ''' 设置httpx默认timeout到60秒。 @@ -80,7 +82,8 @@ class ApiRequest: return httpx.stream("GET", url, params=params, **kwargs) else: return httpx.get(url, params=params, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 async def aget( @@ -100,7 +103,8 @@ class ApiRequest: return await client.stream("GET", url, params=params, **kwargs) else: return await client.get(url, params=params, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 def post( @@ -121,7 +125,8 @@ class ApiRequest: return httpx.stream("POST", url, data=data, json=json, **kwargs) else: return httpx.post(url, data=data, json=json, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 async def apost( @@ -142,7 +147,8 @@ class ApiRequest: return await client.stream("POST", url, data=data, json=json, **kwargs) else: return await client.post(url, data=data, json=json, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 def delete( @@ -162,7 +168,8 @@ class ApiRequest: return httpx.stream("DELETE", url, data=data, json=json, **kwargs) else: return httpx.delete(url, data=data, json=json, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 async def adelete( @@ -183,7 +190,8 @@ class ApiRequest: return await client.stream("DELETE", url, data=data, json=json, **kwargs) else: return await client.delete(url, data=data, json=json, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False): @@ -195,11 +203,14 @@ class ApiRequest: except: loop = asyncio.new_event_loop() - for chunk in iter_over_async(response.body_iterator, loop): - if as_json and chunk: - yield json.loads(chunk) - elif chunk.strip(): - yield chunk + try: + for chunk in iter_over_async(response.body_iterator, loop): + if as_json and chunk: + yield json.loads(chunk) + elif chunk.strip(): + yield chunk + except Exception as e: + logger.error(e) def _httpx_stream2generator( self, @@ -209,12 +220,31 @@ class ApiRequest: ''' 将httpx.stream返回的GeneratorContextManager转化为普通生成器 ''' - with response as r: - for chunk in r.iter_text(None): - if as_json and chunk: - yield json.loads(chunk) - elif chunk.strip(): - yield chunk + try: + with response as r: + for chunk in r.iter_text(None): + if not chunk: # openai api server communicating error + msg = f"API通信超时,请确认已启动FastChat与API服务(详见README '5. 启动 API 服务或 Web UI')" + logger.error(msg) + yield {"code": 500, "errorMsg": msg} + break + if as_json and chunk: + yield json.loads(chunk) + elif chunk.strip(): + yield chunk + except httpx.ConnectError as e: + msg = f"无法连接API服务器,请确认已执行python server\\api.py" + logger.error(msg) + logger.error(e) + yield {"code": 500, "errorMsg": msg} + except httpx.ReadTimeout as e: + msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')" + logger.error(msg) + logger.error(e) + yield {"code": 500, "errorMsg": msg} + except Exception as e: + logger.error(e) + yield {"code": 500, "errorMsg": str(e)} # 对话相关操作 @@ -353,6 +383,21 @@ class ApiRequest: # 知识库相关操作 + def _check_httpx_json_response( + self, + response: httpx.Response, + errorMsg: str = f"无法连接API服务器,请确认已执行python server\\api.py", + ) -> Dict: + ''' + check whether httpx returns correct data with normal Response. + error in api with streaming support was checked in _httpx_stream2enerator + ''' + try: + return response.json() + except Exception as e: + logger.error(e) + return {"code": 500, "errorMsg": errorMsg or str(e)} + def list_knowledge_bases( self, no_remote_api: bool = None, @@ -369,7 +414,8 @@ class ApiRequest: return response.data else: response = self.get("/knowledge_base/list_knowledge_bases") - return response.json().get("data") + data = self._check_httpx_json_response(response) + return data.get("data", []) def create_knowledge_base( self, @@ -399,7 +445,7 @@ class ApiRequest: "/knowledge_base/create_knowledge_base", json=data, ) - return response.json() + return self._check_httpx_json_response(response) def delete_knowledge_base( self, @@ -421,7 +467,7 @@ class ApiRequest: "/knowledge_base/delete_knowledge_base", json=f"{knowledge_base_name}", ) - return response.json() + return self._check_httpx_json_response(response) def list_kb_docs( self, @@ -443,7 +489,8 @@ class ApiRequest: "/knowledge_base/list_docs", params={"knowledge_base_name": knowledge_base_name} ) - return response.json().get("data") + data = self._check_httpx_json_response(response) + return data.get("data", []) def upload_kb_doc( self, @@ -487,7 +534,7 @@ class ApiRequest: data={"knowledge_base_name": knowledge_base_name, "override": override}, files={"file": (filename, file)}, ) - return response.json() + return self._check_httpx_json_response(response) def delete_kb_doc( self, @@ -517,7 +564,7 @@ class ApiRequest: "/knowledge_base/delete_doc", json=data, ) - return response.json() + return self._check_httpx_json_response(response) def update_kb_doc( self, @@ -540,7 +587,7 @@ class ApiRequest: "/knowledge_base/update_doc", json={"knowledge_base_name": knowledge_base_name, "file_name": file_name}, ) - return response.json() + return self._check_httpx_json_response(response) def recreate_vector_store( self, @@ -572,10 +619,20 @@ class ApiRequest: "/knowledge_base/recreate_vector_store", json=data, stream=True, + timeout=False, ) return self._httpx_stream2generator(response, as_json=True) +def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: + ''' + return error message if error occured when requests API + ''' + if isinstance(data, dict) and key in data: + return data[key] + return "" + + if __name__ == "__main__": api = ApiRequest(no_remote_api=True)