修复科大讯飞token问题和Azure的token问题 (#1894)
Co-authored-by: zR <zRzRzRzRzRzRzR>
This commit is contained in:
parent
b68f7fcdea
commit
6ed87954b2
|
|
@ -18,7 +18,7 @@ class AzureWorker(ApiModelWorker):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||||
kwargs.setdefault("context_len", 8192)
|
kwargs.setdefault("context_len", 8000) #TODO 16K模型需要改成16384
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.version = version
|
self.version = version
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ import sys
|
||||||
from server.model_workers.base import ApiEmbeddingsParams
|
from server.model_workers.base import ApiEmbeddingsParams
|
||||||
from typing import List, Literal, Dict
|
from typing import List, Literal, Dict
|
||||||
|
|
||||||
|
|
||||||
MODEL_VERSIONS = {
|
MODEL_VERSIONS = {
|
||||||
"ernie-bot-4": "completions_pro",
|
"ernie-bot-4": "completions_pro",
|
||||||
"ernie-bot": "completions",
|
"ernie-bot": "completions",
|
||||||
|
|
@ -47,7 +46,7 @@ MODEL_VERSIONS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
||||||
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
||||||
"""
|
"""
|
||||||
使用 AK,SK 生成鉴权签名(Access Token)
|
使用 AK,SK 生成鉴权签名(Access Token)
|
||||||
|
|
@ -69,13 +68,13 @@ class QianFanWorker(ApiModelWorker):
|
||||||
DEFAULT_EMBED_MODEL = "embedding-v1"
|
DEFAULT_EMBED_MODEL = "embedding-v1"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
|
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
|
||||||
model_names: List[str] = ["qianfan-api"],
|
model_names: List[str] = ["qianfan-api"],
|
||||||
controller_addr: str = None,
|
controller_addr: str = None,
|
||||||
worker_addr: str = None,
|
worker_addr: str = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||||
kwargs.setdefault("context_len", 16384)
|
kwargs.setdefault("context_len", 16384)
|
||||||
|
|
@ -108,8 +107,8 @@ class QianFanWorker(ApiModelWorker):
|
||||||
# "text": str(resp.body),
|
# "text": str(resp.body),
|
||||||
# }
|
# }
|
||||||
|
|
||||||
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
|
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
|
||||||
'/{model_version}?access_token={access_token}'
|
'/{model_version}?access_token={access_token}'
|
||||||
|
|
||||||
access_token = get_baidu_access_token(params.api_key, params.secret_key)
|
access_token = get_baidu_access_token(params.api_key, params.secret_key)
|
||||||
if not access_token:
|
if not access_token:
|
||||||
|
|
@ -177,7 +176,6 @@ class QianFanWorker(ApiModelWorker):
|
||||||
else:
|
else:
|
||||||
return {"code": resp["error_code"], "msg": resp["error_msg"]}
|
return {"code": resp["error_code"], "msg": resp["error_msg"]}
|
||||||
|
|
||||||
|
|
||||||
# TODO: qianfan支持续写模型
|
# TODO: qianfan支持续写模型
|
||||||
def get_embeddings(self, params):
|
def get_embeddings(self, params):
|
||||||
# TODO: 支持embeddings
|
# TODO: 支持embeddings
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ async def request(appid, api_key, api_secret, Spark_url, domain, question, tempe
|
||||||
wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
|
wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
|
||||||
wsUrl = wsParam.create_url()
|
wsUrl = wsParam.create_url()
|
||||||
data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
|
data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
|
||||||
|
print(data)
|
||||||
async with websockets.connect(wsUrl) as ws:
|
async with websockets.connect(wsUrl) as ws:
|
||||||
await ws.send(json.dumps(data, ensure_ascii=False))
|
await ws.send(json.dumps(data, ensure_ascii=False))
|
||||||
finish = False
|
finish = False
|
||||||
|
|
@ -36,7 +37,7 @@ class XingHuoWorker(ApiModelWorker):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||||
kwargs.setdefault("context_len", 8192)
|
kwargs.setdefault("context_len", 8000) # TODO: V1模型的最大长度为4000,需要自行修改
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.version = version
|
self.version = version
|
||||||
|
|
||||||
|
|
@ -45,15 +46,14 @@ class XingHuoWorker(ApiModelWorker):
|
||||||
params.load_config(self.model_names[0])
|
params.load_config(self.model_names[0])
|
||||||
|
|
||||||
version_mapping = {
|
version_mapping = {
|
||||||
"v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 2048},
|
"v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 4000},
|
||||||
"v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 4096},
|
"v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 8000},
|
||||||
"v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8192},
|
"v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8000},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_version_details(version_key):
|
def get_version_details(version_key):
|
||||||
return version_mapping.get(version_key, {"domain": None, "url": None})
|
return version_mapping.get(version_key, {"domain": None, "url": None})
|
||||||
|
|
||||||
# 使用方法:
|
|
||||||
details = get_version_details(params.version)
|
details = get_version_details(params.version)
|
||||||
domain = details["domain"]
|
domain = details["domain"]
|
||||||
Spark_url = details["url"]
|
Spark_url = details["url"]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue