From 1de4258aa010f58749d4fbfe6e4064ac76ca5a02 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Sat, 25 Nov 2023 13:51:07 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=99=20ApiModelWorker=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=20logger=20=E6=88=90=E5=91=98=E5=8F=98=E9=87=8F=EF=BC=8CAPI?= =?UTF-8?q?=E8=AF=B7=E6=B1=82=E5=87=BA=E9=94=99=E6=97=B6=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E6=9C=89=E6=84=8F=E4=B9=89=E7=9A=84=E9=94=99=E8=AF=AF=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=E3=80=82=20(#2169)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。 /chat/chat 接口 conversation_id参数改为默认 "",避免 swagger 页面默认值错误导致历史消息失效 * 修复在线模型一些bug --- server/chat/chat.py | 4 +-- server/model_workers/azure.py | 6 +++-- server/model_workers/baichuan.py | 4 ++- server/model_workers/base.py | 6 ++--- server/model_workers/fangzhou.py | 42 ++++++++++++++++++++------------ server/model_workers/minimax.py | 8 ++++-- server/model_workers/qianfan.py | 8 ++++-- server/model_workers/qwen.py | 9 ++++--- server/model_workers/tiangong.py | 9 +++---- server/model_workers/xinghuo.py | 2 +- server/model_workers/zhipu.py | 9 +++++-- tests/test_online_api.py | 36 +++++++++++++-------------- 12 files changed, 85 insertions(+), 58 deletions(-) diff --git a/server/chat/chat.py b/server/chat/chat.py index 16b271d..acf3ec0 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -18,7 +18,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), - conversation_id: str = Body(None, description="对话框ID"), + conversation_id: str = Body("", description="对话框ID"), history: Union[int, List[History]] = Body([], description="历史对话,设为一个整数可以从数据库中读取历史消息", examples=[[ @@ -54,7 +54,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 callbacks=callbacks, ) - if conversation_id is None: + if not conversation_id: history = [History.from_data(h) for h in history] prompt_template = get_prompt_template("llm_chat", prompt_name) input_msg = History(role="user", content=prompt_template).to_msg_template(False) diff --git a/server/model_workers/azure.py b/server/model_workers/azure.py index 46d9529..c6b4cbb 100644 --- a/server/model_workers/azure.py +++ b/server/model_workers/azure.py @@ -12,8 +12,8 @@ class AzureWorker(ApiModelWorker): def __init__( self, *, - controller_addr: str, - worker_addr: str, + controller_addr: str = None, + worker_addr: str = None, model_names: List[str] = ["azure-api"], version: str = "gpt-35-turbo", **kwargs, @@ -60,6 +60,8 @@ class AzureWorker(ApiModelWorker): "error_code": 0, "text": text } + else: + self.logger.error(f"请求 Azure API 时发生错误:{resp}") def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/model_workers/baichuan.py b/server/model_workers/baichuan.py index edc81eb..5e9cbbb 100644 --- a/server/model_workers/baichuan.py +++ b/server/model_workers/baichuan.py @@ -74,7 +74,7 @@ class BaiChuanWorker(ApiModelWorker): "text": text } else: - yield { + data = { "error_code": resp["code"], "text": resp["msg"], "error": { @@ -84,6 +84,8 @@ class BaiChuanWorker(ApiModelWorker): "code": None, } } + self.logger.error(f"请求百川 API 时发生错误:{data}") + yield data def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/model_workers/base.py b/server/model_workers/base.py index e015ce3..7b456a9 100644 --- a/server/model_workers/base.py +++ b/server/model_workers/base.py @@ -15,10 +15,6 @@ from typing import Dict, List, Optional __all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"] -# 恢复被fastchat覆盖的标准输出 -sys.stdout = sys.__stdout__ -sys.stderr = sys.__stderr__ - class ApiConfigParams(BaseModel): ''' @@ -110,7 +106,9 @@ class ApiModelWorker(BaseModelWorker): controller_addr=controller_addr, worker_addr=worker_addr, **kwargs) + import fastchat.serve.base_model_worker import sys + self.logger = fastchat.serve.base_model_worker.logger # 恢复被fastchat覆盖的标准输出 sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py index 3834b5a..ddbad4a 100644 --- a/server/model_workers/fangzhou.py +++ b/server/model_workers/fangzhou.py @@ -48,23 +48,33 @@ class FangZhouWorker(ApiModelWorker): text = "" if log_verbose: - logger.info(f'{self.__class__.__name__}:maas: {maas}') + self.logger.info(f'{self.__class__.__name__}:maas: {maas}') for resp in maas.stream_chat(req): - error = resp.error - if error.code_n > 0: - yield { - "error_code": error.code_n, - "text": error.message, - "error": { - "message": error.message, - "type": "invalid_request_error", - "param": None, - "code": None, - } - } - elif chunk := resp.choice.message.content: - text += chunk - yield {"error_code": 0, "text": text} + if error := resp.error: + if error.code_n > 0: + data = { + "error_code": error.code_n, + "text": error.message, + "error": { + "message": error.message, + "type": "invalid_request_error", + "param": None, + "code": None, + } + } + self.logger.error(f"请求方舟 API 时发生错误:{data}") + yield data + elif chunk := resp.choice.message.content: + text += chunk + yield {"error_code": 0, "text": text} + else: + data = { + "error_code": 500, + "text": f"请求方舟 API 时发生未知的错误: {resp}" + } + self.logger.error(data) + yield data + break def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py index 220ed58..45e9a5f 100644 --- a/server/model_workers/minimax.py +++ b/server/model_workers/minimax.py @@ -74,7 +74,7 @@ class MiniMaxWorker(ApiModelWorker): text = "" for e in r.iter_text(): if not e.startswith("data: "): # 真是优秀的返回 - yield { + data = { "error_code": 500, "text": f"minimax返回错误的结果:{e}", "error": { @@ -84,6 +84,8 @@ class MiniMaxWorker(ApiModelWorker): "code": None, } } + self.logger.error(f"请求 MiniMax API 时发生错误:{data}") + yield data continue data = json.loads(e[6:]) @@ -123,7 +125,7 @@ class MiniMaxWorker(ApiModelWorker): if embeddings := r.get("vectors"): result += embeddings elif error := r.get("base_resp"): - return { + data = { "code": error["status_code"], "msg": error["status_msg"], "error": { @@ -133,6 +135,8 @@ class MiniMaxWorker(ApiModelWorker): "code": None, } } + self.logger.error(f"请求 MiniMax API 时发生错误:{data}") + return data i += 10 return {"code": 200, "data": embeddings} diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index 2ab39d1..95e200a 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -154,7 +154,7 @@ class QianFanWorker(ApiModelWorker): "text": text } else: - yield { + data = { "error_code": resp["error_code"], "text": resp["error_msg"], "error": { @@ -164,6 +164,8 @@ class QianFanWorker(ApiModelWorker): "code": None, } } + self.logger.error(f"请求千帆 API 时发生错误:{data}") + yield data def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: params.load_config(self.model_names[0]) @@ -189,7 +191,7 @@ class QianFanWorker(ApiModelWorker): for texts in params.texts[i:i+10]: resp = client.post(url, json={"input": texts}).json() if "error_cdoe" in resp: - return { + data = { "code": resp["error_code"], "msg": resp["error_msg"], "error": { @@ -199,6 +201,8 @@ class QianFanWorker(ApiModelWorker): "code": None, } } + self.logger.error(f"请求千帆 API 时发生错误:{data}") + return data else: embeddings = [x["embedding"] for x in resp.get("data", [])] result += embeddings diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py index fde0a4c..58d1bcd 100644 --- a/server/model_workers/qwen.py +++ b/server/model_workers/qwen.py @@ -53,7 +53,7 @@ class QwenWorker(ApiModelWorker): "text": choices[0]["message"]["content"], } else: - yield { + data = { "error_code": resp["status_code"], "text": resp["message"], "error": { @@ -63,7 +63,8 @@ class QwenWorker(ApiModelWorker): "code": None, } } - + self.logger.error(f"请求千问 API 时发生错误:{data}") + yield data def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: import dashscope @@ -80,7 +81,7 @@ class QwenWorker(ApiModelWorker): api_key=params.api_key, ) if resp["status_code"] != 200: - return { + data = { "code": resp["status_code"], "msg": resp.message, "error": { @@ -90,6 +91,8 @@ class QwenWorker(ApiModelWorker): "code": None, } } + self.logger.error(f"请求千问 API 时发生错误:{data}") + return data else: embeddings = [x["embedding"] for x in resp["output"]["embeddings"]] result += embeddings diff --git a/server/model_workers/tiangong.py b/server/model_workers/tiangong.py index 5ff3de0..85a763f 100644 --- a/server/model_workers/tiangong.py +++ b/server/model_workers/tiangong.py @@ -6,7 +6,6 @@ from fastchat.conversation import Conversation from server.model_workers.base import * from server.utils import get_httpx_client from fastchat import conversation as conv -import sys import json from typing import List, Literal, Dict import requests @@ -64,12 +63,12 @@ class TianGongWorker(ApiModelWorker): "text": text } else: - yield { + data = { "error_code": resp["code"], "text": resp["code_msg"] - } - - + } + self.logger.error(f"请求天工 API 时出错:{data}") + yield data def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py index e866609..72db738 100644 --- a/server/model_workers/xinghuo.py +++ b/server/model_workers/xinghuo.py @@ -62,7 +62,7 @@ class XingHuoWorker(ApiModelWorker): loop = asyncio.get_event_loop() except: loop = asyncio.new_event_loop() - params.max_tokens = min(details["max_tokens"], params.max_tokens) + params.max_tokens = min(details["max_tokens"], params.max_tokens or 0) for chunk in iter_over_async( request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages, params.temperature, params.max_tokens), diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index cafb114..4341937 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -44,7 +44,7 @@ class ChatGLMWorker(ApiModelWorker): if e.event == "add": yield {"error_code": 0, "text": e.data} elif e.event in ["error", "interrupted"]: - yield { + data = { "error_code": 500, "text": str(e), "error": { @@ -54,6 +54,8 @@ class ChatGLMWorker(ApiModelWorker): "code": None, } } + self.logger.error(f"请求智谱 API 时发生错误:{data}") + yield data def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: import zhipuai @@ -68,9 +70,12 @@ class ChatGLMWorker(ApiModelWorker): if response["code"] == 200: embeddings.append(response["data"]["embedding"]) else: + self.logger.error(f"请求智谱 API 时发生错误:{response}") return response # dict with code & msg except Exception as e: - return {"code": 500, "msg": f"对文本向量化时出错:{e}"} + self.logger.error(f"请求智谱 API 时发生错误:{data}") + data = {"code": 500, "msg": f"对文本向量化时出错:{e}"} + return data return {"code": 200, "data": embeddings} diff --git a/tests/test_online_api.py b/tests/test_online_api.py index 7c72f3e..b33d134 100644 --- a/tests/test_online_api.py +++ b/tests/test_online_api.py @@ -16,7 +16,7 @@ for x in list_config_llm_models()["online"]: workers.append(x) print(f"all workers to test: {workers}") -# workers = ["qianfan-api"] +# workers = ["fangzhou-api"] @pytest.mark.parametrize("worker", workers) @@ -28,11 +28,11 @@ def test_chat(worker): ) print(f"\nchat with {worker} \n") - worker_class = get_model_worker_config(worker)["worker_class"] - for x in worker_class().do_chat(params): - pprint(x) - assert isinstance(x, dict) - assert x["error_code"] == 0 + if worker_class := get_model_worker_config(worker).get("worker_class"): + for x in worker_class().do_chat(params): + pprint(x) + assert isinstance(x, dict) + assert x["error_code"] == 0 @pytest.mark.parametrize("worker", workers) @@ -44,19 +44,19 @@ def test_embeddings(worker): ] ) - worker_class = get_model_worker_config(worker)["worker_class"] - if worker_class.can_embedding(): - print(f"\embeddings with {worker} \n") - resp = worker_class().do_embeddings(params) + if worker_class := get_model_worker_config(worker).get("worker_class"): + if worker_class.can_embedding(): + print(f"\embeddings with {worker} \n") + resp = worker_class().do_embeddings(params) - pprint(resp, depth=2) - assert resp["code"] == 200 - assert "data" in resp - embeddings = resp["data"] - assert isinstance(embeddings, list) and len(embeddings) > 0 - assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0 - assert isinstance(embeddings[0][0], float) - print("向量长度:", len(embeddings[0])) + pprint(resp, depth=2) + assert resp["code"] == 200 + assert "data" in resp + embeddings = resp["data"] + assert isinstance(embeddings, list) and len(embeddings) > 0 + assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0 + assert isinstance(embeddings[0][0], float) + print("向量长度:", len(embeddings[0])) # @pytest.mark.parametrize("worker", workers)