给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。 (#2169)
* 给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。 /chat/chat 接口 conversation_id参数改为默认 "",避免 swagger 页面默认值错误导致历史消息失效 * 修复在线模型一些bug
This commit is contained in:
parent
1b0cf67a57
commit
1de4258aa0
|
|
@ -18,7 +18,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa
|
|||
|
||||
|
||||
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
conversation_id: str = Body(None, description="对话框ID"),
|
||||
conversation_id: str = Body("", description="对话框ID"),
|
||||
history: Union[int, List[History]] = Body([],
|
||||
description="历史对话,设为一个整数可以从数据库中读取历史消息",
|
||||
examples=[[
|
||||
|
|
@ -54,7 +54,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
|||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
if conversation_id is None:
|
||||
if not conversation_id:
|
||||
history = [History.from_data(h) for h in history]
|
||||
prompt_template = get_prompt_template("llm_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ class AzureWorker(ApiModelWorker):
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
model_names: List[str] = ["azure-api"],
|
||||
version: str = "gpt-35-turbo",
|
||||
**kwargs,
|
||||
|
|
@ -60,6 +60,8 @@ class AzureWorker(ApiModelWorker):
|
|||
"error_code": 0,
|
||||
"text": text
|
||||
}
|
||||
else:
|
||||
self.logger.error(f"请求 Azure API 时发生错误:{resp}")
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class BaiChuanWorker(ApiModelWorker):
|
|||
"text": text
|
||||
}
|
||||
else:
|
||||
yield {
|
||||
data = {
|
||||
"error_code": resp["code"],
|
||||
"text": resp["msg"],
|
||||
"error": {
|
||||
|
|
@ -84,6 +84,8 @@ class BaiChuanWorker(ApiModelWorker):
|
|||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求百川 API 时发生错误:{data}")
|
||||
yield data
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
|
|
|
|||
|
|
@ -15,10 +15,6 @@ from typing import Dict, List, Optional
|
|||
|
||||
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
|
||||
|
||||
# 恢复被fastchat覆盖的标准输出
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
|
||||
class ApiConfigParams(BaseModel):
|
||||
'''
|
||||
|
|
@ -110,7 +106,9 @@ class ApiModelWorker(BaseModelWorker):
|
|||
controller_addr=controller_addr,
|
||||
worker_addr=worker_addr,
|
||||
**kwargs)
|
||||
import fastchat.serve.base_model_worker
|
||||
import sys
|
||||
self.logger = fastchat.serve.base_model_worker.logger
|
||||
# 恢复被fastchat覆盖的标准输出
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
|
|
|
|||
|
|
@ -48,23 +48,33 @@ class FangZhouWorker(ApiModelWorker):
|
|||
|
||||
text = ""
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:maas: {maas}')
|
||||
self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
|
||||
for resp in maas.stream_chat(req):
|
||||
error = resp.error
|
||||
if error.code_n > 0:
|
||||
yield {
|
||||
"error_code": error.code_n,
|
||||
"text": error.message,
|
||||
"error": {
|
||||
"message": error.message,
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
elif chunk := resp.choice.message.content:
|
||||
text += chunk
|
||||
yield {"error_code": 0, "text": text}
|
||||
if error := resp.error:
|
||||
if error.code_n > 0:
|
||||
data = {
|
||||
"error_code": error.code_n,
|
||||
"text": error.message,
|
||||
"error": {
|
||||
"message": error.message,
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求方舟 API 时发生错误:{data}")
|
||||
yield data
|
||||
elif chunk := resp.choice.message.content:
|
||||
text += chunk
|
||||
yield {"error_code": 0, "text": text}
|
||||
else:
|
||||
data = {
|
||||
"error_code": 500,
|
||||
"text": f"请求方舟 API 时发生未知的错误: {resp}"
|
||||
}
|
||||
self.logger.error(data)
|
||||
yield data
|
||||
break
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class MiniMaxWorker(ApiModelWorker):
|
|||
text = ""
|
||||
for e in r.iter_text():
|
||||
if not e.startswith("data: "): # 真是优秀的返回
|
||||
yield {
|
||||
data = {
|
||||
"error_code": 500,
|
||||
"text": f"minimax返回错误的结果:{e}",
|
||||
"error": {
|
||||
|
|
@ -84,6 +84,8 @@ class MiniMaxWorker(ApiModelWorker):
|
|||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
|
||||
yield data
|
||||
continue
|
||||
|
||||
data = json.loads(e[6:])
|
||||
|
|
@ -123,7 +125,7 @@ class MiniMaxWorker(ApiModelWorker):
|
|||
if embeddings := r.get("vectors"):
|
||||
result += embeddings
|
||||
elif error := r.get("base_resp"):
|
||||
return {
|
||||
data = {
|
||||
"code": error["status_code"],
|
||||
"msg": error["status_msg"],
|
||||
"error": {
|
||||
|
|
@ -133,6 +135,8 @@ class MiniMaxWorker(ApiModelWorker):
|
|||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
|
||||
return data
|
||||
i += 10
|
||||
return {"code": 200, "data": embeddings}
|
||||
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class QianFanWorker(ApiModelWorker):
|
|||
"text": text
|
||||
}
|
||||
else:
|
||||
yield {
|
||||
data = {
|
||||
"error_code": resp["error_code"],
|
||||
"text": resp["error_msg"],
|
||||
"error": {
|
||||
|
|
@ -164,6 +164,8 @@ class QianFanWorker(ApiModelWorker):
|
|||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求千帆 API 时发生错误:{data}")
|
||||
yield data
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
params.load_config(self.model_names[0])
|
||||
|
|
@ -189,7 +191,7 @@ class QianFanWorker(ApiModelWorker):
|
|||
for texts in params.texts[i:i+10]:
|
||||
resp = client.post(url, json={"input": texts}).json()
|
||||
if "error_cdoe" in resp:
|
||||
return {
|
||||
data = {
|
||||
"code": resp["error_code"],
|
||||
"msg": resp["error_msg"],
|
||||
"error": {
|
||||
|
|
@ -199,6 +201,8 @@ class QianFanWorker(ApiModelWorker):
|
|||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求千帆 API 时发生错误:{data}")
|
||||
return data
|
||||
else:
|
||||
embeddings = [x["embedding"] for x in resp.get("data", [])]
|
||||
result += embeddings
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class QwenWorker(ApiModelWorker):
|
|||
"text": choices[0]["message"]["content"],
|
||||
}
|
||||
else:
|
||||
yield {
|
||||
data = {
|
||||
"error_code": resp["status_code"],
|
||||
"text": resp["message"],
|
||||
"error": {
|
||||
|
|
@ -63,7 +63,8 @@ class QwenWorker(ApiModelWorker):
|
|||
"code": None,
|
||||
}
|
||||
}
|
||||
|
||||
self.logger.error(f"请求千问 API 时发生错误:{data}")
|
||||
yield data
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
import dashscope
|
||||
|
|
@ -80,7 +81,7 @@ class QwenWorker(ApiModelWorker):
|
|||
api_key=params.api_key,
|
||||
)
|
||||
if resp["status_code"] != 200:
|
||||
return {
|
||||
data = {
|
||||
"code": resp["status_code"],
|
||||
"msg": resp.message,
|
||||
"error": {
|
||||
|
|
@ -90,6 +91,8 @@ class QwenWorker(ApiModelWorker):
|
|||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求千问 API 时发生错误:{data}")
|
||||
return data
|
||||
else:
|
||||
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
|
||||
result += embeddings
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from fastchat.conversation import Conversation
|
|||
from server.model_workers.base import *
|
||||
from server.utils import get_httpx_client
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Literal, Dict
|
||||
import requests
|
||||
|
|
@ -64,12 +63,12 @@ class TianGongWorker(ApiModelWorker):
|
|||
"text": text
|
||||
}
|
||||
else:
|
||||
yield {
|
||||
data = {
|
||||
"error_code": resp["code"],
|
||||
"text": resp["code_msg"]
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
self.logger.error(f"请求天工 API 时出错:{data}")
|
||||
yield data
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class XingHuoWorker(ApiModelWorker):
|
|||
loop = asyncio.get_event_loop()
|
||||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
params.max_tokens = min(details["max_tokens"], params.max_tokens)
|
||||
params.max_tokens = min(details["max_tokens"], params.max_tokens or 0)
|
||||
for chunk in iter_over_async(
|
||||
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
|
||||
params.temperature, params.max_tokens),
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class ChatGLMWorker(ApiModelWorker):
|
|||
if e.event == "add":
|
||||
yield {"error_code": 0, "text": e.data}
|
||||
elif e.event in ["error", "interrupted"]:
|
||||
yield {
|
||||
data = {
|
||||
"error_code": 500,
|
||||
"text": str(e),
|
||||
"error": {
|
||||
|
|
@ -54,6 +54,8 @@ class ChatGLMWorker(ApiModelWorker):
|
|||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求智谱 API 时发生错误:{data}")
|
||||
yield data
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
import zhipuai
|
||||
|
|
@ -68,9 +70,12 @@ class ChatGLMWorker(ApiModelWorker):
|
|||
if response["code"] == 200:
|
||||
embeddings.append(response["data"]["embedding"])
|
||||
else:
|
||||
self.logger.error(f"请求智谱 API 时发生错误:{response}")
|
||||
return response # dict with code & msg
|
||||
except Exception as e:
|
||||
return {"code": 500, "msg": f"对文本向量化时出错:{e}"}
|
||||
self.logger.error(f"请求智谱 API 时发生错误:{data}")
|
||||
data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
|
||||
return data
|
||||
|
||||
return {"code": 200, "data": embeddings}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ for x in list_config_llm_models()["online"]:
|
|||
workers.append(x)
|
||||
print(f"all workers to test: {workers}")
|
||||
|
||||
# workers = ["qianfan-api"]
|
||||
# workers = ["fangzhou-api"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("worker", workers)
|
||||
|
|
@ -28,11 +28,11 @@ def test_chat(worker):
|
|||
)
|
||||
print(f"\nchat with {worker} \n")
|
||||
|
||||
worker_class = get_model_worker_config(worker)["worker_class"]
|
||||
for x in worker_class().do_chat(params):
|
||||
pprint(x)
|
||||
assert isinstance(x, dict)
|
||||
assert x["error_code"] == 0
|
||||
if worker_class := get_model_worker_config(worker).get("worker_class"):
|
||||
for x in worker_class().do_chat(params):
|
||||
pprint(x)
|
||||
assert isinstance(x, dict)
|
||||
assert x["error_code"] == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("worker", workers)
|
||||
|
|
@ -44,19 +44,19 @@ def test_embeddings(worker):
|
|||
]
|
||||
)
|
||||
|
||||
worker_class = get_model_worker_config(worker)["worker_class"]
|
||||
if worker_class.can_embedding():
|
||||
print(f"\embeddings with {worker} \n")
|
||||
resp = worker_class().do_embeddings(params)
|
||||
if worker_class := get_model_worker_config(worker).get("worker_class"):
|
||||
if worker_class.can_embedding():
|
||||
print(f"\embeddings with {worker} \n")
|
||||
resp = worker_class().do_embeddings(params)
|
||||
|
||||
pprint(resp, depth=2)
|
||||
assert resp["code"] == 200
|
||||
assert "data" in resp
|
||||
embeddings = resp["data"]
|
||||
assert isinstance(embeddings, list) and len(embeddings) > 0
|
||||
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
|
||||
assert isinstance(embeddings[0][0], float)
|
||||
print("向量长度:", len(embeddings[0]))
|
||||
pprint(resp, depth=2)
|
||||
assert resp["code"] == 200
|
||||
assert "data" in resp
|
||||
embeddings = resp["data"]
|
||||
assert isinstance(embeddings, list) and len(embeddings) > 0
|
||||
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
|
||||
assert isinstance(embeddings[0][0], float)
|
||||
print("向量长度:", len(embeddings[0]))
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("worker", workers)
|
||||
|
|
|
|||
Loading…
Reference in New Issue