Add more models supported by the Qianfan platform; besides specifying a model name, allow specifying the model API URL directly, which makes it easier to configure a separately requested model endpoint

liunux4odoo 2023-09-15 00:30:18 +08:00
parent 4cf2e5ea5e
commit f0f1dc2537
2 changed files with 34 additions and 3 deletions

View File

@@ -105,7 +105,8 @@ llm_model_dict = {
},
# Baidu Qianfan. For how to apply for API access, see https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
"qianfan-api": {
"version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo"
"version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo" 更多的见文档模型支持列表中千帆部分。
"version_url": "", # 可以不填写version直接填写在千帆申请模型发布的API地址
"api_base_url": "http://127.0.0.1:8888/v1",
"api_key": "",
"secret_key": "",

View File

@@ -9,10 +9,39 @@ from server.utils import get_model_worker_config
from typing import List, Literal, Dict
# TODO: support all qianfan models
MODEL_VERSIONS = {
"ernie-bot": "completions",
"ernie-bot-turbo": "eb-instant",
"bloomz-7b": "bloomz_7b1",
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
"llama2-7b-chat": "llama_2_7b",
"llama2-13b-chat": "llama_2_13b",
"llama2-70b-chat": "llama_2_70b",
"qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
"chatglm2-6b-32k": "chatglm2_6b_32k",
"aquilachat-7b": "aquilachat_7b",
# "linly-llama2-ch-7b": "", # 暂未发布
# "linly-llama2-ch-13b": "", # 暂未发布
# "chatglm2-6b": "", # 暂未发布
# "chatglm2-6b-int4": "", # 暂未发布
# "falcon-7b": "", # 暂未发布
# "falcon-180b-chat": "", # 暂未发布
# "falcon-40b": "", # 暂未发布
# "rwkv4-world": "", # 暂未发布
# "rwkv5-world": "", # 暂未发布
# "rwkv4-pile-14b": "", # 暂未发布
# "rwkv4-raven-14b": "", # 暂未发布
# "open-llama-7b": "", # 暂未发布
# "dolly-12b": "", # 暂未发布
# "mpt-7b-instruct": "", # 暂未发布
# "mpt-30b-instruct": "", # 暂未发布
# "OA-Pythia-12B-SFT-4": "", # 暂未发布
# "xverse-13b": "", # 暂未发布
# # 以下为企业测试,需要单独申请
# "flan-ul2": "",
# "Cerebras-GPT-6.7B": ""
# "Pythia-6.9B": ""
}
@@ -40,12 +69,13 @@ def request_qianfan_api(
'/{model_version}?access_token={access_token}'
config = get_model_worker_config(model_name)
version = version or config.get("version")
version_url = config.get("version_url")
access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
if not access_token:
raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
url = BASE_URL.format(
model_version=MODEL_VERSIONS[version],
model_version=version_url or MODEL_VERSIONS[version],
access_token=access_token,
)
payload = {
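
To make the new fallback concrete, here is a small self-contained sketch of how the request URL is assembled. The full BASE_URL prefix is an assumption (the diff only shows its tail), the helper name build_qianfan_url is hypothetical, and MODEL_VERSIONS refers to the mapping added above.

    # Sketch only: the chat prefix below is assumed, since the diff truncates BASE_URL.
    QIANFAN_CHAT_PREFIX = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
    BASE_URL = QIANFAN_CHAT_PREFIX + "/{model_version}?access_token={access_token}"

    def build_qianfan_url(config: dict, access_token: str, version: str = "") -> str:
        """Mirror the fallback in request_qianfan_api: a non-empty version_url wins."""
        version = version or config.get("version")
        version_url = config.get("version_url")
        return BASE_URL.format(
            model_version=version_url or MODEL_VERSIONS[version],
            access_token=access_token,
        )

    # With version="ernie-bot" this yields ".../chat/completions?access_token=<token>",
    # whereas a non-empty version_url is substituted verbatim into the {model_version} slot.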