给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。 (#2169)

* 给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。
/chat/chat 接口的 conversation_id 参数默认值由 None 改为 ""(空字符串),避免 swagger 页面填入的默认值错误导致历史消息失效。

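* 修复在线模型的一些 bug(如 AzureWorker 的 controller_addr/worker_addr 增加默认值、XingHuoWorker 兼容 max_tokens 为空、FangZhouWorker 处理未知错误响应)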
This commit is contained in:
liunux4odoo 2023-11-25 13:51:07 +08:00 committed by GitHub
parent 1b0cf67a57
commit 1de4258aa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 85 additions and 58 deletions

View File

@ -18,7 +18,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
conversation_id: str = Body(None, description="对话框ID"), conversation_id: str = Body("", description="对话框ID"),
history: Union[int, List[History]] = Body([], history: Union[int, List[History]] = Body([],
description="历史对话,设为一个整数可以从数据库中读取历史消息", description="历史对话,设为一个整数可以从数据库中读取历史消息",
examples=[[ examples=[[
@ -54,7 +54,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callbacks=callbacks, callbacks=callbacks,
) )
if conversation_id is None: if not conversation_id:
history = [History.from_data(h) for h in history] history = [History.from_data(h) for h in history]
prompt_template = get_prompt_template("llm_chat", prompt_name) prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False) input_msg = History(role="user", content=prompt_template).to_msg_template(False)

View File

@ -12,8 +12,8 @@ class AzureWorker(ApiModelWorker):
def __init__( def __init__(
self, self,
*, *,
controller_addr: str, controller_addr: str = None,
worker_addr: str, worker_addr: str = None,
model_names: List[str] = ["azure-api"], model_names: List[str] = ["azure-api"],
version: str = "gpt-35-turbo", version: str = "gpt-35-turbo",
**kwargs, **kwargs,
@ -60,6 +60,8 @@ class AzureWorker(ApiModelWorker):
"error_code": 0, "error_code": 0,
"text": text "text": text
} }
else:
self.logger.error(f"请求 Azure API 时发生错误:{resp}")
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings

View File

@ -74,7 +74,7 @@ class BaiChuanWorker(ApiModelWorker):
"text": text "text": text
} }
else: else:
yield { data = {
"error_code": resp["code"], "error_code": resp["code"],
"text": resp["msg"], "text": resp["msg"],
"error": { "error": {
@ -84,6 +84,8 @@ class BaiChuanWorker(ApiModelWorker):
"code": None, "code": None,
} }
} }
self.logger.error(f"请求百川 API 时发生错误:{data}")
yield data
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings

View File

@ -15,10 +15,6 @@ from typing import Dict, List, Optional
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"] __all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
# 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
class ApiConfigParams(BaseModel): class ApiConfigParams(BaseModel):
''' '''
@ -110,7 +106,9 @@ class ApiModelWorker(BaseModelWorker):
controller_addr=controller_addr, controller_addr=controller_addr,
worker_addr=worker_addr, worker_addr=worker_addr,
**kwargs) **kwargs)
import fastchat.serve.base_model_worker
import sys import sys
self.logger = fastchat.serve.base_model_worker.logger
# 恢复被fastchat覆盖的标准输出 # 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__ sys.stderr = sys.__stderr__

View File

@ -48,23 +48,33 @@ class FangZhouWorker(ApiModelWorker):
text = "" text = ""
if log_verbose: if log_verbose:
logger.info(f'{self.__class__.__name__}:maas: {maas}') self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
for resp in maas.stream_chat(req): for resp in maas.stream_chat(req):
error = resp.error if error := resp.error:
if error.code_n > 0: if error.code_n > 0:
yield { data = {
"error_code": error.code_n, "error_code": error.code_n,
"text": error.message, "text": error.message,
"error": { "error": {
"message": error.message, "message": error.message,
"type": "invalid_request_error", "type": "invalid_request_error",
"param": None, "param": None,
"code": None, "code": None,
} }
} }
elif chunk := resp.choice.message.content: self.logger.error(f"请求方舟 API 时发生错误:{data}")
text += chunk yield data
yield {"error_code": 0, "text": text} elif chunk := resp.choice.message.content:
text += chunk
yield {"error_code": 0, "text": text}
else:
data = {
"error_code": 500,
"text": f"请求方舟 API 时发生未知的错误: {resp}"
}
self.logger.error(data)
yield data
break
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings

View File

@ -74,7 +74,7 @@ class MiniMaxWorker(ApiModelWorker):
text = "" text = ""
for e in r.iter_text(): for e in r.iter_text():
if not e.startswith("data: "): # 真是优秀的返回 if not e.startswith("data: "): # 真是优秀的返回
yield { data = {
"error_code": 500, "error_code": 500,
"text": f"minimax返回错误的结果{e}", "text": f"minimax返回错误的结果{e}",
"error": { "error": {
@ -84,6 +84,8 @@ class MiniMaxWorker(ApiModelWorker):
"code": None, "code": None,
} }
} }
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
yield data
continue continue
data = json.loads(e[6:]) data = json.loads(e[6:])
@ -123,7 +125,7 @@ class MiniMaxWorker(ApiModelWorker):
if embeddings := r.get("vectors"): if embeddings := r.get("vectors"):
result += embeddings result += embeddings
elif error := r.get("base_resp"): elif error := r.get("base_resp"):
return { data = {
"code": error["status_code"], "code": error["status_code"],
"msg": error["status_msg"], "msg": error["status_msg"],
"error": { "error": {
@ -133,6 +135,8 @@ class MiniMaxWorker(ApiModelWorker):
"code": None, "code": None,
} }
} }
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data
i += 10 i += 10
return {"code": 200, "data": embeddings} return {"code": 200, "data": embeddings}

View File

@ -154,7 +154,7 @@ class QianFanWorker(ApiModelWorker):
"text": text "text": text
} }
else: else:
yield { data = {
"error_code": resp["error_code"], "error_code": resp["error_code"],
"text": resp["error_msg"], "text": resp["error_msg"],
"error": { "error": {
@ -164,6 +164,8 @@ class QianFanWorker(ApiModelWorker):
"code": None, "code": None,
} }
} }
self.logger.error(f"请求千帆 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
params.load_config(self.model_names[0]) params.load_config(self.model_names[0])
@ -189,7 +191,7 @@ class QianFanWorker(ApiModelWorker):
for texts in params.texts[i:i+10]: for texts in params.texts[i:i+10]:
resp = client.post(url, json={"input": texts}).json() resp = client.post(url, json={"input": texts}).json()
if "error_cdoe" in resp: if "error_cdoe" in resp:
return { data = {
"code": resp["error_code"], "code": resp["error_code"],
"msg": resp["error_msg"], "msg": resp["error_msg"],
"error": { "error": {
@ -199,6 +201,8 @@ class QianFanWorker(ApiModelWorker):
"code": None, "code": None,
} }
} }
self.logger.error(f"请求千帆 API 时发生错误:{data}")
return data
else: else:
embeddings = [x["embedding"] for x in resp.get("data", [])] embeddings = [x["embedding"] for x in resp.get("data", [])]
result += embeddings result += embeddings

View File

@ -53,7 +53,7 @@ class QwenWorker(ApiModelWorker):
"text": choices[0]["message"]["content"], "text": choices[0]["message"]["content"],
} }
else: else:
yield { data = {
"error_code": resp["status_code"], "error_code": resp["status_code"],
"text": resp["message"], "text": resp["message"],
"error": { "error": {
@ -63,7 +63,8 @@ class QwenWorker(ApiModelWorker):
"code": None, "code": None,
} }
} }
self.logger.error(f"请求千问 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import dashscope import dashscope
@ -80,7 +81,7 @@ class QwenWorker(ApiModelWorker):
api_key=params.api_key, api_key=params.api_key,
) )
if resp["status_code"] != 200: if resp["status_code"] != 200:
return { data = {
"code": resp["status_code"], "code": resp["status_code"],
"msg": resp.message, "msg": resp.message,
"error": { "error": {
@ -90,6 +91,8 @@ class QwenWorker(ApiModelWorker):
"code": None, "code": None,
} }
} }
self.logger.error(f"请求千问 API 时发生错误:{data}")
return data
else: else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]] embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
result += embeddings result += embeddings

View File

@ -6,7 +6,6 @@ from fastchat.conversation import Conversation
from server.model_workers.base import * from server.model_workers.base import *
from server.utils import get_httpx_client from server.utils import get_httpx_client
from fastchat import conversation as conv from fastchat import conversation as conv
import sys
import json import json
from typing import List, Literal, Dict from typing import List, Literal, Dict
import requests import requests
@ -64,12 +63,12 @@ class TianGongWorker(ApiModelWorker):
"text": text "text": text
} }
else: else:
yield { data = {
"error_code": resp["code"], "error_code": resp["code"],
"text": resp["code_msg"] "text": resp["code_msg"]
} }
self.logger.error(f"请求天工 API 时出错:{data}")
yield data
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings

View File

@ -62,7 +62,7 @@ class XingHuoWorker(ApiModelWorker):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
except: except:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
params.max_tokens = min(details["max_tokens"], params.max_tokens) params.max_tokens = min(details["max_tokens"], params.max_tokens or 0)
for chunk in iter_over_async( for chunk in iter_over_async(
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages, request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
params.temperature, params.max_tokens), params.temperature, params.max_tokens),

View File

@ -44,7 +44,7 @@ class ChatGLMWorker(ApiModelWorker):
if e.event == "add": if e.event == "add":
yield {"error_code": 0, "text": e.data} yield {"error_code": 0, "text": e.data}
elif e.event in ["error", "interrupted"]: elif e.event in ["error", "interrupted"]:
yield { data = {
"error_code": 500, "error_code": 500,
"text": str(e), "text": str(e),
"error": { "error": {
@ -54,6 +54,8 @@ class ChatGLMWorker(ApiModelWorker):
"code": None, "code": None,
} }
} }
self.logger.error(f"请求智谱 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import zhipuai import zhipuai
@ -68,9 +70,12 @@ class ChatGLMWorker(ApiModelWorker):
if response["code"] == 200: if response["code"] == 200:
embeddings.append(response["data"]["embedding"]) embeddings.append(response["data"]["embedding"])
else: else:
self.logger.error(f"请求智谱 API 时发生错误:{response}")
return response # dict with code & msg return response # dict with code & msg
except Exception as e: except Exception as e:
return {"code": 500, "msg": f"对文本向量化时出错:{e}"} self.logger.error(f"请求智谱 API 时发生错误:{data}")
data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
return data
return {"code": 200, "data": embeddings} return {"code": 200, "data": embeddings}

View File

@ -16,7 +16,7 @@ for x in list_config_llm_models()["online"]:
workers.append(x) workers.append(x)
print(f"all workers to test: {workers}") print(f"all workers to test: {workers}")
# workers = ["qianfan-api"] # workers = ["fangzhou-api"]
@pytest.mark.parametrize("worker", workers) @pytest.mark.parametrize("worker", workers)
@ -28,11 +28,11 @@ def test_chat(worker):
) )
print(f"\nchat with {worker} \n") print(f"\nchat with {worker} \n")
worker_class = get_model_worker_config(worker)["worker_class"] if worker_class := get_model_worker_config(worker).get("worker_class"):
for x in worker_class().do_chat(params): for x in worker_class().do_chat(params):
pprint(x) pprint(x)
assert isinstance(x, dict) assert isinstance(x, dict)
assert x["error_code"] == 0 assert x["error_code"] == 0
@pytest.mark.parametrize("worker", workers) @pytest.mark.parametrize("worker", workers)
@ -44,19 +44,19 @@ def test_embeddings(worker):
] ]
) )
worker_class = get_model_worker_config(worker)["worker_class"] if worker_class := get_model_worker_config(worker).get("worker_class"):
if worker_class.can_embedding(): if worker_class.can_embedding():
print(f"\embeddings with {worker} \n") print(f"\embeddings with {worker} \n")
resp = worker_class().do_embeddings(params) resp = worker_class().do_embeddings(params)
pprint(resp, depth=2) pprint(resp, depth=2)
assert resp["code"] == 200 assert resp["code"] == 200
assert "data" in resp assert "data" in resp
embeddings = resp["data"] embeddings = resp["data"]
assert isinstance(embeddings, list) and len(embeddings) > 0 assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0 assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float) assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0])) print("向量长度:", len(embeddings[0]))
# @pytest.mark.parametrize("worker", workers) # @pytest.mark.parametrize("worker", workers)