修复科大讯飞token问题和Azure的token问题 (#1894)
Co-authored-by: zR <zRzRzRzRzRzRzR>
This commit is contained in:
parent
b68f7fcdea
commit
6ed87954b2
|
|
@ -18,7 +18,7 @@ class AzureWorker(ApiModelWorker):
|
|||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 8192)
|
||||
kwargs.setdefault("context_len", 8000) #TODO 16K模型需要改成16384
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import sys
|
|||
from server.model_workers.base import ApiEmbeddingsParams
|
||||
from typing import List, Literal, Dict
|
||||
|
||||
|
||||
MODEL_VERSIONS = {
|
||||
"ernie-bot-4": "completions_pro",
|
||||
"ernie-bot": "completions",
|
||||
|
|
@ -47,7 +46,7 @@ MODEL_VERSIONS = {
|
|||
}
|
||||
|
||||
|
||||
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
||||
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
||||
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
||||
"""
|
||||
使用 AK,SK 生成鉴权签名(Access Token)
|
||||
|
|
@ -69,13 +68,13 @@ class QianFanWorker(ApiModelWorker):
|
|||
DEFAULT_EMBED_MODEL = "embedding-v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
|
||||
model_names: List[str] = ["qianfan-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
**kwargs,
|
||||
self,
|
||||
*,
|
||||
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
|
||||
model_names: List[str] = ["qianfan-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 16384)
|
||||
|
|
@ -108,8 +107,8 @@ class QianFanWorker(ApiModelWorker):
|
|||
# "text": str(resp.body),
|
||||
# }
|
||||
|
||||
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
|
||||
'/{model_version}?access_token={access_token}'
|
||||
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
|
||||
'/{model_version}?access_token={access_token}'
|
||||
|
||||
access_token = get_baidu_access_token(params.api_key, params.secret_key)
|
||||
if not access_token:
|
||||
|
|
@ -131,7 +130,7 @@ class QianFanWorker(ApiModelWorker):
|
|||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
|
||||
|
||||
text = ""
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=payload) as response:
|
||||
|
|
@ -177,7 +176,6 @@ class QianFanWorker(ApiModelWorker):
|
|||
else:
|
||||
return {"code": resp["error_code"], "msg": resp["error_msg"]}
|
||||
|
||||
|
||||
# TODO: qianfan支持续写模型
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
|
|
@ -207,4 +205,4 @@ if __name__ == "__main__":
|
|||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21004)
|
||||
uvicorn.run(app, port=21004)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ async def request(appid, api_key, api_secret, Spark_url, domain, question, tempe
|
|||
wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
|
||||
wsUrl = wsParam.create_url()
|
||||
data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
|
||||
print(data)
|
||||
async with websockets.connect(wsUrl) as ws:
|
||||
await ws.send(json.dumps(data, ensure_ascii=False))
|
||||
finish = False
|
||||
|
|
@ -36,7 +37,7 @@ class XingHuoWorker(ApiModelWorker):
|
|||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 8192)
|
||||
kwargs.setdefault("context_len", 8000) # TODO: V1模型的最大长度为4000,需要自行修改
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
|
|
@ -45,15 +46,14 @@ class XingHuoWorker(ApiModelWorker):
|
|||
params.load_config(self.model_names[0])
|
||||
|
||||
version_mapping = {
|
||||
"v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 2048},
|
||||
"v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 4096},
|
||||
"v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8192},
|
||||
"v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 4000},
|
||||
"v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 8000},
|
||||
"v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8000},
|
||||
}
|
||||
|
||||
def get_version_details(version_key):
|
||||
return version_mapping.get(version_key, {"domain": None, "url": None})
|
||||
|
||||
# 使用方法:
|
||||
details = get_version_details(params.version)
|
||||
domain = details["domain"]
|
||||
Spark_url = details["url"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue