给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。 (#2169)

* 给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。
/chat/chat 接口 conversation_id参数改为默认 "",避免 swagger 页面默认值错误导致历史消息失效

* 修复在线模型一些bug
This commit is contained in:
liunux4odoo 2023-11-25 13:51:07 +08:00 committed by GitHub
parent 1b0cf67a57
commit 1de4258aa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 85 additions and 58 deletions

View File

@ -18,7 +18,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
conversation_id: str = Body(None, description="对话框ID"),
conversation_id: str = Body("", description="对话框ID"),
history: Union[int, List[History]] = Body([],
description="历史对话,设为一个整数可以从数据库中读取历史消息",
examples=[[
@ -54,7 +54,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callbacks=callbacks,
)
if conversation_id is None:
if not conversation_id:
history = [History.from_data(h) for h in history]
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)

View File

@ -12,8 +12,8 @@ class AzureWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str,
worker_addr: str,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["azure-api"],
version: str = "gpt-35-turbo",
**kwargs,
@ -60,6 +60,8 @@ class AzureWorker(ApiModelWorker):
"error_code": 0,
"text": text
}
else:
self.logger.error(f"请求 Azure API 时发生错误:{resp}")
def get_embeddings(self, params):
# TODO: 支持embeddings

View File

@ -74,7 +74,7 @@ class BaiChuanWorker(ApiModelWorker):
"text": text
}
else:
yield {
data = {
"error_code": resp["code"],
"text": resp["msg"],
"error": {
@ -84,6 +84,8 @@ class BaiChuanWorker(ApiModelWorker):
"code": None,
}
}
self.logger.error(f"请求百川 API 时发生错误:{data}")
yield data
def get_embeddings(self, params):
# TODO: 支持embeddings

View File

@ -15,10 +15,6 @@ from typing import Dict, List, Optional
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
# 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
class ApiConfigParams(BaseModel):
'''
@ -110,7 +106,9 @@ class ApiModelWorker(BaseModelWorker):
controller_addr=controller_addr,
worker_addr=worker_addr,
**kwargs)
import fastchat.serve.base_model_worker
import sys
self.logger = fastchat.serve.base_model_worker.logger
# 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__

View File

@ -48,23 +48,33 @@ class FangZhouWorker(ApiModelWorker):
text = ""
if log_verbose:
logger.info(f'{self.__class__.__name__}:maas: {maas}')
self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
for resp in maas.stream_chat(req):
error = resp.error
if error.code_n > 0:
yield {
"error_code": error.code_n,
"text": error.message,
"error": {
"message": error.message,
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
elif chunk := resp.choice.message.content:
text += chunk
yield {"error_code": 0, "text": text}
if error := resp.error:
if error.code_n > 0:
data = {
"error_code": error.code_n,
"text": error.message,
"error": {
"message": error.message,
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求方舟 API 时发生错误:{data}")
yield data
elif chunk := resp.choice.message.content:
text += chunk
yield {"error_code": 0, "text": text}
else:
data = {
"error_code": 500,
"text": f"请求方舟 API 时发生未知的错误: {resp}"
}
self.logger.error(data)
yield data
break
def get_embeddings(self, params):
# TODO: 支持embeddings

View File

@ -74,7 +74,7 @@ class MiniMaxWorker(ApiModelWorker):
text = ""
for e in r.iter_text():
if not e.startswith("data: "): # 真是优秀的返回
yield {
data = {
"error_code": 500,
"text": f"minimax返回错误的结果{e}",
"error": {
@ -84,6 +84,8 @@ class MiniMaxWorker(ApiModelWorker):
"code": None,
}
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
yield data
continue
data = json.loads(e[6:])
@ -123,7 +125,7 @@ class MiniMaxWorker(ApiModelWorker):
if embeddings := r.get("vectors"):
result += embeddings
elif error := r.get("base_resp"):
return {
data = {
"code": error["status_code"],
"msg": error["status_msg"],
"error": {
@ -133,6 +135,8 @@ class MiniMaxWorker(ApiModelWorker):
"code": None,
}
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data
i += 10
return {"code": 200, "data": embeddings}

View File

@ -154,7 +154,7 @@ class QianFanWorker(ApiModelWorker):
"text": text
}
else:
yield {
data = {
"error_code": resp["error_code"],
"text": resp["error_msg"],
"error": {
@ -164,6 +164,8 @@ class QianFanWorker(ApiModelWorker):
"code": None,
}
}
self.logger.error(f"请求千帆 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
params.load_config(self.model_names[0])
@ -189,7 +191,7 @@ class QianFanWorker(ApiModelWorker):
for texts in params.texts[i:i+10]:
resp = client.post(url, json={"input": texts}).json()
if "error_cdoe" in resp:
return {
data = {
"code": resp["error_code"],
"msg": resp["error_msg"],
"error": {
@ -199,6 +201,8 @@ class QianFanWorker(ApiModelWorker):
"code": None,
}
}
self.logger.error(f"请求千帆 API 时发生错误:{data}")
return data
else:
embeddings = [x["embedding"] for x in resp.get("data", [])]
result += embeddings

View File

@ -53,7 +53,7 @@ class QwenWorker(ApiModelWorker):
"text": choices[0]["message"]["content"],
}
else:
yield {
data = {
"error_code": resp["status_code"],
"text": resp["message"],
"error": {
@ -63,7 +63,8 @@ class QwenWorker(ApiModelWorker):
"code": None,
}
}
self.logger.error(f"请求千问 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import dashscope
@ -80,7 +81,7 @@ class QwenWorker(ApiModelWorker):
api_key=params.api_key,
)
if resp["status_code"] != 200:
return {
data = {
"code": resp["status_code"],
"msg": resp.message,
"error": {
@ -90,6 +91,8 @@ class QwenWorker(ApiModelWorker):
"code": None,
}
}
self.logger.error(f"请求千问 API 时发生错误:{data}")
return data
else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
result += embeddings

View File

@ -6,7 +6,6 @@ from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal, Dict
import requests
@ -64,12 +63,12 @@ class TianGongWorker(ApiModelWorker):
"text": text
}
else:
yield {
data = {
"error_code": resp["code"],
"text": resp["code_msg"]
}
}
self.logger.error(f"请求天工 API 时出错:{data}")
yield data
def get_embeddings(self, params):
# TODO: 支持embeddings

View File

@ -62,7 +62,7 @@ class XingHuoWorker(ApiModelWorker):
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
params.max_tokens = min(details["max_tokens"], params.max_tokens)
params.max_tokens = min(details["max_tokens"], params.max_tokens or 0)
for chunk in iter_over_async(
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
params.temperature, params.max_tokens),

View File

@ -44,7 +44,7 @@ class ChatGLMWorker(ApiModelWorker):
if e.event == "add":
yield {"error_code": 0, "text": e.data}
elif e.event in ["error", "interrupted"]:
yield {
data = {
"error_code": 500,
"text": str(e),
"error": {
@ -54,6 +54,8 @@ class ChatGLMWorker(ApiModelWorker):
"code": None,
}
}
self.logger.error(f"请求智谱 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import zhipuai
@ -68,9 +70,12 @@ class ChatGLMWorker(ApiModelWorker):
if response["code"] == 200:
embeddings.append(response["data"]["embedding"])
else:
self.logger.error(f"请求智谱 API 时发生错误:{response}")
return response # dict with code & msg
except Exception as e:
return {"code": 500, "msg": f"对文本向量化时出错:{e}"}
self.logger.error(f"请求智谱 API 时发生错误:{data}")
data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
return data
return {"code": 200, "data": embeddings}

View File

@ -16,7 +16,7 @@ for x in list_config_llm_models()["online"]:
workers.append(x)
print(f"all workers to test: {workers}")
# workers = ["qianfan-api"]
# workers = ["fangzhou-api"]
@pytest.mark.parametrize("worker", workers)
@ -28,11 +28,11 @@ def test_chat(worker):
)
print(f"\nchat with {worker} \n")
worker_class = get_model_worker_config(worker)["worker_class"]
for x in worker_class().do_chat(params):
pprint(x)
assert isinstance(x, dict)
assert x["error_code"] == 0
if worker_class := get_model_worker_config(worker).get("worker_class"):
for x in worker_class().do_chat(params):
pprint(x)
assert isinstance(x, dict)
assert x["error_code"] == 0
@pytest.mark.parametrize("worker", workers)
@ -44,19 +44,19 @@ def test_embeddings(worker):
]
)
worker_class = get_model_worker_config(worker)["worker_class"]
if worker_class.can_embedding():
print(f"\embeddings with {worker} \n")
resp = worker_class().do_embeddings(params)
if worker_class := get_model_worker_config(worker).get("worker_class"):
if worker_class.can_embedding():
print(f"\embeddings with {worker} \n")
resp = worker_class().do_embeddings(params)
pprint(resp, depth=2)
assert resp["code"] == 200
assert "data" in resp
embeddings = resp["data"]
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0]))
pprint(resp, depth=2)
assert resp["code"] == 200
assert "data" in resp
embeddings = resp["data"]
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0]))
# @pytest.mark.parametrize("worker", workers)