给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。 (#2169)
* 给 ApiModelWorker 添加 logger 成员变量,API 请求出错时输出有意义的错误信息。
  /chat/chat 接口的 conversation_id 参数默认值改为 "",避免 swagger 页面默认值错误导致历史消息失效。
* 修复在线模型的一些 bug。
This commit is contained in:
parent
1b0cf67a57
commit
1de4258aa0
|
|
@ -18,7 +18,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa
|
||||||
|
|
||||||
|
|
||||||
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||||
conversation_id: str = Body(None, description="对话框ID"),
|
conversation_id: str = Body("", description="对话框ID"),
|
||||||
history: Union[int, List[History]] = Body([],
|
history: Union[int, List[History]] = Body([],
|
||||||
description="历史对话,设为一个整数可以从数据库中读取历史消息",
|
description="历史对话,设为一个整数可以从数据库中读取历史消息",
|
||||||
examples=[[
|
examples=[[
|
||||||
|
|
@ -54,7 +54,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
if conversation_id is None:
|
if not conversation_id:
|
||||||
history = [History.from_data(h) for h in history]
|
history = [History.from_data(h) for h in history]
|
||||||
prompt_template = get_prompt_template("llm_chat", prompt_name)
|
prompt_template = get_prompt_template("llm_chat", prompt_name)
|
||||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,8 @@ class AzureWorker(ApiModelWorker):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
controller_addr: str,
|
controller_addr: str = None,
|
||||||
worker_addr: str,
|
worker_addr: str = None,
|
||||||
model_names: List[str] = ["azure-api"],
|
model_names: List[str] = ["azure-api"],
|
||||||
version: str = "gpt-35-turbo",
|
version: str = "gpt-35-turbo",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
@ -60,6 +60,8 @@ class AzureWorker(ApiModelWorker):
|
||||||
"error_code": 0,
|
"error_code": 0,
|
||||||
"text": text
|
"text": text
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
self.logger.error(f"请求 Azure API 时发生错误:{resp}")
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
def get_embeddings(self, params):
|
||||||
# TODO: 支持embeddings
|
# TODO: 支持embeddings
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ class BaiChuanWorker(ApiModelWorker):
|
||||||
"text": text
|
"text": text
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
yield {
|
data = {
|
||||||
"error_code": resp["code"],
|
"error_code": resp["code"],
|
||||||
"text": resp["msg"],
|
"text": resp["msg"],
|
||||||
"error": {
|
"error": {
|
||||||
|
|
@ -84,6 +84,8 @@ class BaiChuanWorker(ApiModelWorker):
|
||||||
"code": None,
|
"code": None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.logger.error(f"请求百川 API 时发生错误:{data}")
|
||||||
|
yield data
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
def get_embeddings(self, params):
|
||||||
# TODO: 支持embeddings
|
# TODO: 支持embeddings
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,6 @@ from typing import Dict, List, Optional
|
||||||
|
|
||||||
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
|
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
|
||||||
|
|
||||||
# 恢复被fastchat覆盖的标准输出
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
sys.stderr = sys.__stderr__
|
|
||||||
|
|
||||||
|
|
||||||
class ApiConfigParams(BaseModel):
|
class ApiConfigParams(BaseModel):
|
||||||
'''
|
'''
|
||||||
|
|
@ -110,7 +106,9 @@ class ApiModelWorker(BaseModelWorker):
|
||||||
controller_addr=controller_addr,
|
controller_addr=controller_addr,
|
||||||
worker_addr=worker_addr,
|
worker_addr=worker_addr,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
import fastchat.serve.base_model_worker
|
||||||
import sys
|
import sys
|
||||||
|
self.logger = fastchat.serve.base_model_worker.logger
|
||||||
# 恢复被fastchat覆盖的标准输出
|
# 恢复被fastchat覆盖的标准输出
|
||||||
sys.stdout = sys.__stdout__
|
sys.stdout = sys.__stdout__
|
||||||
sys.stderr = sys.__stderr__
|
sys.stderr = sys.__stderr__
|
||||||
|
|
|
||||||
|
|
@ -48,23 +48,33 @@ class FangZhouWorker(ApiModelWorker):
|
||||||
|
|
||||||
text = ""
|
text = ""
|
||||||
if log_verbose:
|
if log_verbose:
|
||||||
logger.info(f'{self.__class__.__name__}:maas: {maas}')
|
self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
|
||||||
for resp in maas.stream_chat(req):
|
for resp in maas.stream_chat(req):
|
||||||
error = resp.error
|
if error := resp.error:
|
||||||
if error.code_n > 0:
|
if error.code_n > 0:
|
||||||
yield {
|
data = {
|
||||||
"error_code": error.code_n,
|
"error_code": error.code_n,
|
||||||
"text": error.message,
|
"text": error.message,
|
||||||
"error": {
|
"error": {
|
||||||
"message": error.message,
|
"message": error.message,
|
||||||
"type": "invalid_request_error",
|
"type": "invalid_request_error",
|
||||||
"param": None,
|
"param": None,
|
||||||
"code": None,
|
"code": None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
elif chunk := resp.choice.message.content:
|
self.logger.error(f"请求方舟 API 时发生错误:{data}")
|
||||||
text += chunk
|
yield data
|
||||||
yield {"error_code": 0, "text": text}
|
elif chunk := resp.choice.message.content:
|
||||||
|
text += chunk
|
||||||
|
yield {"error_code": 0, "text": text}
|
||||||
|
else:
|
||||||
|
data = {
|
||||||
|
"error_code": 500,
|
||||||
|
"text": f"请求方舟 API 时发生未知的错误: {resp}"
|
||||||
|
}
|
||||||
|
self.logger.error(data)
|
||||||
|
yield data
|
||||||
|
break
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
def get_embeddings(self, params):
|
||||||
# TODO: 支持embeddings
|
# TODO: 支持embeddings
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ class MiniMaxWorker(ApiModelWorker):
|
||||||
text = ""
|
text = ""
|
||||||
for e in r.iter_text():
|
for e in r.iter_text():
|
||||||
if not e.startswith("data: "): # 真是优秀的返回
|
if not e.startswith("data: "): # 真是优秀的返回
|
||||||
yield {
|
data = {
|
||||||
"error_code": 500,
|
"error_code": 500,
|
||||||
"text": f"minimax返回错误的结果:{e}",
|
"text": f"minimax返回错误的结果:{e}",
|
||||||
"error": {
|
"error": {
|
||||||
|
|
@ -84,6 +84,8 @@ class MiniMaxWorker(ApiModelWorker):
|
||||||
"code": None,
|
"code": None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
|
||||||
|
yield data
|
||||||
continue
|
continue
|
||||||
|
|
||||||
data = json.loads(e[6:])
|
data = json.loads(e[6:])
|
||||||
|
|
@ -123,7 +125,7 @@ class MiniMaxWorker(ApiModelWorker):
|
||||||
if embeddings := r.get("vectors"):
|
if embeddings := r.get("vectors"):
|
||||||
result += embeddings
|
result += embeddings
|
||||||
elif error := r.get("base_resp"):
|
elif error := r.get("base_resp"):
|
||||||
return {
|
data = {
|
||||||
"code": error["status_code"],
|
"code": error["status_code"],
|
||||||
"msg": error["status_msg"],
|
"msg": error["status_msg"],
|
||||||
"error": {
|
"error": {
|
||||||
|
|
@ -133,6 +135,8 @@ class MiniMaxWorker(ApiModelWorker):
|
||||||
"code": None,
|
"code": None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
|
||||||
|
return data
|
||||||
i += 10
|
i += 10
|
||||||
return {"code": 200, "data": embeddings}
|
return {"code": 200, "data": embeddings}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -154,7 +154,7 @@ class QianFanWorker(ApiModelWorker):
|
||||||
"text": text
|
"text": text
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
yield {
|
data = {
|
||||||
"error_code": resp["error_code"],
|
"error_code": resp["error_code"],
|
||||||
"text": resp["error_msg"],
|
"text": resp["error_msg"],
|
||||||
"error": {
|
"error": {
|
||||||
|
|
@ -164,6 +164,8 @@ class QianFanWorker(ApiModelWorker):
|
||||||
"code": None,
|
"code": None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.logger.error(f"请求千帆 API 时发生错误:{data}")
|
||||||
|
yield data
|
||||||
|
|
||||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||||
params.load_config(self.model_names[0])
|
params.load_config(self.model_names[0])
|
||||||
|
|
@ -189,7 +191,7 @@ class QianFanWorker(ApiModelWorker):
|
||||||
for texts in params.texts[i:i+10]:
|
for texts in params.texts[i:i+10]:
|
||||||
resp = client.post(url, json={"input": texts}).json()
|
resp = client.post(url, json={"input": texts}).json()
|
||||||
if "error_cdoe" in resp:
|
if "error_cdoe" in resp:
|
||||||
return {
|
data = {
|
||||||
"code": resp["error_code"],
|
"code": resp["error_code"],
|
||||||
"msg": resp["error_msg"],
|
"msg": resp["error_msg"],
|
||||||
"error": {
|
"error": {
|
||||||
|
|
@ -199,6 +201,8 @@ class QianFanWorker(ApiModelWorker):
|
||||||
"code": None,
|
"code": None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.logger.error(f"请求千帆 API 时发生错误:{data}")
|
||||||
|
return data
|
||||||
else:
|
else:
|
||||||
embeddings = [x["embedding"] for x in resp.get("data", [])]
|
embeddings = [x["embedding"] for x in resp.get("data", [])]
|
||||||
result += embeddings
|
result += embeddings
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ class QwenWorker(ApiModelWorker):
|
||||||
"text": choices[0]["message"]["content"],
|
"text": choices[0]["message"]["content"],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
yield {
|
data = {
|
||||||
"error_code": resp["status_code"],
|
"error_code": resp["status_code"],
|
||||||
"text": resp["message"],
|
"text": resp["message"],
|
||||||
"error": {
|
"error": {
|
||||||
|
|
@ -63,7 +63,8 @@ class QwenWorker(ApiModelWorker):
|
||||||
"code": None,
|
"code": None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.logger.error(f"请求千问 API 时发生错误:{data}")
|
||||||
|
yield data
|
||||||
|
|
||||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||||
import dashscope
|
import dashscope
|
||||||
|
|
@ -80,7 +81,7 @@ class QwenWorker(ApiModelWorker):
|
||||||
api_key=params.api_key,
|
api_key=params.api_key,
|
||||||
)
|
)
|
||||||
if resp["status_code"] != 200:
|
if resp["status_code"] != 200:
|
||||||
return {
|
data = {
|
||||||
"code": resp["status_code"],
|
"code": resp["status_code"],
|
||||||
"msg": resp.message,
|
"msg": resp.message,
|
||||||
"error": {
|
"error": {
|
||||||
|
|
@ -90,6 +91,8 @@ class QwenWorker(ApiModelWorker):
|
||||||
"code": None,
|
"code": None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.logger.error(f"请求千问 API 时发生错误:{data}")
|
||||||
|
return data
|
||||||
else:
|
else:
|
||||||
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
|
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
|
||||||
result += embeddings
|
result += embeddings
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from fastchat.conversation import Conversation
|
||||||
from server.model_workers.base import *
|
from server.model_workers.base import *
|
||||||
from server.utils import get_httpx_client
|
from server.utils import get_httpx_client
|
||||||
from fastchat import conversation as conv
|
from fastchat import conversation as conv
|
||||||
import sys
|
|
||||||
import json
|
import json
|
||||||
from typing import List, Literal, Dict
|
from typing import List, Literal, Dict
|
||||||
import requests
|
import requests
|
||||||
|
|
@ -64,12 +63,12 @@ class TianGongWorker(ApiModelWorker):
|
||||||
"text": text
|
"text": text
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
yield {
|
data = {
|
||||||
"error_code": resp["code"],
|
"error_code": resp["code"],
|
||||||
"text": resp["code_msg"]
|
"text": resp["code_msg"]
|
||||||
}
|
}
|
||||||
|
self.logger.error(f"请求天工 API 时出错:{data}")
|
||||||
|
yield data
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
def get_embeddings(self, params):
|
||||||
# TODO: 支持embeddings
|
# TODO: 支持embeddings
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ class XingHuoWorker(ApiModelWorker):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
except:
|
except:
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
params.max_tokens = min(details["max_tokens"], params.max_tokens)
|
params.max_tokens = min(details["max_tokens"], params.max_tokens or 0)
|
||||||
for chunk in iter_over_async(
|
for chunk in iter_over_async(
|
||||||
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
|
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
|
||||||
params.temperature, params.max_tokens),
|
params.temperature, params.max_tokens),
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ class ChatGLMWorker(ApiModelWorker):
|
||||||
if e.event == "add":
|
if e.event == "add":
|
||||||
yield {"error_code": 0, "text": e.data}
|
yield {"error_code": 0, "text": e.data}
|
||||||
elif e.event in ["error", "interrupted"]:
|
elif e.event in ["error", "interrupted"]:
|
||||||
yield {
|
data = {
|
||||||
"error_code": 500,
|
"error_code": 500,
|
||||||
"text": str(e),
|
"text": str(e),
|
||||||
"error": {
|
"error": {
|
||||||
|
|
@ -54,6 +54,8 @@ class ChatGLMWorker(ApiModelWorker):
|
||||||
"code": None,
|
"code": None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
self.logger.error(f"请求智谱 API 时发生错误:{data}")
|
||||||
|
yield data
|
||||||
|
|
||||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||||
import zhipuai
|
import zhipuai
|
||||||
|
|
@ -68,9 +70,12 @@ class ChatGLMWorker(ApiModelWorker):
|
||||||
if response["code"] == 200:
|
if response["code"] == 200:
|
||||||
embeddings.append(response["data"]["embedding"])
|
embeddings.append(response["data"]["embedding"])
|
||||||
else:
|
else:
|
||||||
|
self.logger.error(f"请求智谱 API 时发生错误:{response}")
|
||||||
return response # dict with code & msg
|
return response # dict with code & msg
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"code": 500, "msg": f"对文本向量化时出错:{e}"}
|
self.logger.error(f"请求智谱 API 时发生错误:{data}")
|
||||||
|
data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
|
||||||
|
return data
|
||||||
|
|
||||||
return {"code": 200, "data": embeddings}
|
return {"code": 200, "data": embeddings}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ for x in list_config_llm_models()["online"]:
|
||||||
workers.append(x)
|
workers.append(x)
|
||||||
print(f"all workers to test: {workers}")
|
print(f"all workers to test: {workers}")
|
||||||
|
|
||||||
# workers = ["qianfan-api"]
|
# workers = ["fangzhou-api"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("worker", workers)
|
@pytest.mark.parametrize("worker", workers)
|
||||||
|
|
@ -28,11 +28,11 @@ def test_chat(worker):
|
||||||
)
|
)
|
||||||
print(f"\nchat with {worker} \n")
|
print(f"\nchat with {worker} \n")
|
||||||
|
|
||||||
worker_class = get_model_worker_config(worker)["worker_class"]
|
if worker_class := get_model_worker_config(worker).get("worker_class"):
|
||||||
for x in worker_class().do_chat(params):
|
for x in worker_class().do_chat(params):
|
||||||
pprint(x)
|
pprint(x)
|
||||||
assert isinstance(x, dict)
|
assert isinstance(x, dict)
|
||||||
assert x["error_code"] == 0
|
assert x["error_code"] == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("worker", workers)
|
@pytest.mark.parametrize("worker", workers)
|
||||||
|
|
@ -44,19 +44,19 @@ def test_embeddings(worker):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
worker_class = get_model_worker_config(worker)["worker_class"]
|
if worker_class := get_model_worker_config(worker).get("worker_class"):
|
||||||
if worker_class.can_embedding():
|
if worker_class.can_embedding():
|
||||||
print(f"\embeddings with {worker} \n")
|
print(f"\embeddings with {worker} \n")
|
||||||
resp = worker_class().do_embeddings(params)
|
resp = worker_class().do_embeddings(params)
|
||||||
|
|
||||||
pprint(resp, depth=2)
|
pprint(resp, depth=2)
|
||||||
assert resp["code"] == 200
|
assert resp["code"] == 200
|
||||||
assert "data" in resp
|
assert "data" in resp
|
||||||
embeddings = resp["data"]
|
embeddings = resp["data"]
|
||||||
assert isinstance(embeddings, list) and len(embeddings) > 0
|
assert isinstance(embeddings, list) and len(embeddings) > 0
|
||||||
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
|
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
|
||||||
assert isinstance(embeddings[0][0], float)
|
assert isinstance(embeddings[0][0], float)
|
||||||
print("向量长度:", len(embeddings[0]))
|
print("向量长度:", len(embeddings[0]))
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("worker", workers)
|
# @pytest.mark.parametrize("worker", workers)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue