补充更多千帆平台支持的模型;除了指定模型名称外,还支持直接指定模型的 API URL,便于填写单独申请的模型地址
This commit is contained in:
parent
4cf2e5ea5e
commit
f0f1dc2537
|
|
@ -105,7 +105,8 @@ llm_model_dict = {
|
||||||
},
|
},
|
||||||
# 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
|
# 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
|
||||||
"qianfan-api": {
|
"qianfan-api": {
|
||||||
"version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo"
|
"version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo", 更多的见文档模型支持列表中千帆部分。
|
||||||
|
"version_url": "", # 可以不填写version,直接填写在千帆申请模型发布的API地址
|
||||||
"api_base_url": "http://127.0.0.1:8888/v1",
|
"api_base_url": "http://127.0.0.1:8888/v1",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"secret_key": "",
|
"secret_key": "",
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,39 @@ from server.utils import get_model_worker_config
|
||||||
from typing import List, Literal, Dict
|
from typing import List, Literal, Dict
|
||||||
|
|
||||||
|
|
||||||
# TODO: support all qianfan models
|
|
||||||
MODEL_VERSIONS = {
|
MODEL_VERSIONS = {
|
||||||
"ernie-bot": "completions",
|
"ernie-bot": "completions",
|
||||||
"ernie-bot-turbo": "eb-instant",
|
"ernie-bot-turbo": "eb-instant",
|
||||||
|
"bloomz-7b": "bloomz_7b1",
|
||||||
|
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
|
||||||
|
"llama2-7b-chat": "llama_2_7b",
|
||||||
|
"llama2-13b-chat": "llama_2_13b",
|
||||||
|
"llama2-70b-chat": "llama_2_70b",
|
||||||
|
"qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
|
||||||
|
"chatglm2-6b-32k": "chatglm2_6b_32k",
|
||||||
|
"aquilachat-7b": "aquilachat_7b",
|
||||||
|
# "linly-llama2-ch-7b": "", # 暂未发布
|
||||||
|
# "linly-llama2-ch-13b": "", # 暂未发布
|
||||||
|
# "chatglm2-6b": "", # 暂未发布
|
||||||
|
# "chatglm2-6b-int4": "", # 暂未发布
|
||||||
|
# "falcon-7b": "", # 暂未发布
|
||||||
|
# "falcon-180b-chat": "", # 暂未发布
|
||||||
|
# "falcon-40b": "", # 暂未发布
|
||||||
|
# "rwkv4-world": "", # 暂未发布
|
||||||
|
# "rwkv5-world": "", # 暂未发布
|
||||||
|
# "rwkv4-pile-14b": "", # 暂未发布
|
||||||
|
# "rwkv4-raven-14b": "", # 暂未发布
|
||||||
|
# "open-llama-7b": "", # 暂未发布
|
||||||
|
# "dolly-12b": "", # 暂未发布
|
||||||
|
# "mpt-7b-instruct": "", # 暂未发布
|
||||||
|
# "mpt-30b-instruct": "", # 暂未发布
|
||||||
|
# "OA-Pythia-12B-SFT-4": "", # 暂未发布
|
||||||
|
# "xverse-13b": "", # 暂未发布
|
||||||
|
|
||||||
|
# # 以下为企业测试,需要单独申请
|
||||||
|
# "flan-ul2": "",
|
||||||
|
# "Cerebras-GPT-6.7B": "",
|
||||||
|
# "Pythia-6.9B": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -40,12 +69,13 @@ def request_qianfan_api(
|
||||||
'/{model_version}?access_token={access_token}'
|
'/{model_version}?access_token={access_token}'
|
||||||
config = get_model_worker_config(model_name)
|
config = get_model_worker_config(model_name)
|
||||||
version = version or config.get("version")
|
version = version or config.get("version")
|
||||||
|
version_url = config.get("version_url")
|
||||||
access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
|
access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
|
||||||
if not access_token:
|
if not access_token:
|
||||||
raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
|
raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
|
||||||
|
|
||||||
url = BASE_URL.format(
|
url = BASE_URL.format(
|
||||||
model_version=MODEL_VERSIONS[version],
|
model_version=version_url or MODEL_VERSIONS[version],
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
payload = {
|
payload = {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue