修复科大讯飞token问题和Azure的token问题 (#1894)

Co-authored-by: zR <zRzRzRzRzRzRzR>
This commit is contained in:
zR 2023-10-27 13:51:59 +08:00 committed by GitHub
parent b68f7fcdea
commit 6ed87954b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 20 deletions

View File

@ -18,7 +18,7 @@ class AzureWorker(ApiModelWorker):
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8192)
kwargs.setdefault("context_len", 8000) #TODO 16K模型需要改成16384
super().__init__(**kwargs)
self.version = version

View File

@ -9,7 +9,6 @@ import sys
from server.model_workers.base import ApiEmbeddingsParams
from typing import List, Literal, Dict
MODEL_VERSIONS = {
"ernie-bot-4": "completions_pro",
"ernie-bot": "completions",
@ -47,7 +46,7 @@ MODEL_VERSIONS = {
}
@cached(TTLCache(1, 1800)) # 经过测试缓存的token可以使用目前每30分钟刷新一次
@cached(TTLCache(1, 1800)) # 经过测试缓存的token可以使用目前每30分钟刷新一次
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
"""
使用 AKSK 生成鉴权签名Access Token
@ -69,13 +68,13 @@ class QianFanWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "embedding-v1"
def __init__(
self,
*,
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
model_names: List[str] = ["qianfan-api"],
controller_addr: str = None,
worker_addr: str = None,
**kwargs,
self,
*,
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
model_names: List[str] = ["qianfan-api"],
controller_addr: str = None,
worker_addr: str = None,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
@ -108,8 +107,8 @@ class QianFanWorker(ApiModelWorker):
# "text": str(resp.body),
# }
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
'/{model_version}?access_token={access_token}'
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
'/{model_version}?access_token={access_token}'
access_token = get_baidu_access_token(params.api_key, params.secret_key)
if not access_token:
@ -131,7 +130,7 @@ class QianFanWorker(ApiModelWorker):
'Content-Type': 'application/json',
'Accept': 'application/json',
}
text = ""
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=payload) as response:
@ -177,7 +176,6 @@ class QianFanWorker(ApiModelWorker):
else:
return {"code": resp["error_code"], "msg": resp["error_msg"]}
# TODO: qianfan支持续写模型
def get_embeddings(self, params):
# TODO: 支持embeddings
@ -207,4 +205,4 @@ if __name__ == "__main__":
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21004)
uvicorn.run(app, port=21004)

View File

@ -13,6 +13,7 @@ async def request(appid, api_key, api_secret, Spark_url, domain, question, tempe
wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
wsUrl = wsParam.create_url()
data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
print(data)
async with websockets.connect(wsUrl) as ws:
await ws.send(json.dumps(data, ensure_ascii=False))
finish = False
@ -36,7 +37,7 @@ class XingHuoWorker(ApiModelWorker):
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8192)
kwargs.setdefault("context_len", 8000) # TODO: V1模型的最大长度为4000需要自行修改
super().__init__(**kwargs)
self.version = version
@ -45,15 +46,14 @@ class XingHuoWorker(ApiModelWorker):
params.load_config(self.model_names[0])
version_mapping = {
"v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 2048},
"v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 4096},
"v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8192},
"v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 4000},
"v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 8000},
"v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8000},
}
def get_version_details(version_key):
return version_mapping.get(version_key, {"domain": None, "url": None})
# 使用方法:
details = get_version_details(params.version)
domain = details["domain"]
Spark_url = details["url"]