修复文心一言,添加测试用例

This commit is contained in:
liunux4odoo 2023-09-14 23:37:34 +08:00
parent fbaca1009e
commit 4cf2e5ea5e
6 changed files with 170 additions and 125 deletions

View File

@ -103,9 +103,9 @@ llm_model_dict = {
"is_v2": False, "is_v2": False,
"provider": "XingHuoWorker", "provider": "XingHuoWorker",
}, },
# Ernie Bot Turbo API申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf # 百度千帆 API申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
"ernie-api": { "qianfan-api": {
"version": "ernie-bot-turbo", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo" "version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo"
"api_base_url": "http://127.0.0.1:8888/v1", "api_base_url": "http://127.0.0.1:8888/v1",
"api_key": "", "api_key": "",
"secret_key": "", "secret_key": "",

View File

@ -59,6 +59,7 @@ FSCHAT_MODEL_WORKERS = {
# "limit_worker_concurrency": 5, # "limit_worker_concurrency": 5,
# "stream_interval": 2, # "stream_interval": 2,
# "no_register": False, # "no_register": False,
# "embed_in_truncate": False,
}, },
"baichuan-7b": { # 使用default中的IP和端口 "baichuan-7b": { # 使用default中的IP和端口
"device": "cpu", "device": "cpu",
@ -72,6 +73,9 @@ FSCHAT_MODEL_WORKERS = {
"xinghuo-api": { # 请为每个在线API设置不同的端口 "xinghuo-api": { # 请为每个在线API设置不同的端口
"port": 20005, "port": 20005,
}, },
"qianfan-api": {
"port": 20006,
},
} }
# fastchat multi model worker server # fastchat multi model worker server

View File

@ -1,4 +1,4 @@
from .zhipu import ChatGLMWorker from .zhipu import ChatGLMWorker
from .minimax import MiniMaxWorker from .minimax import MiniMaxWorker
from .xinghuo import XingHuoWorker from .xinghuo import XingHuoWorker
from .ernie import ErnieWorker from .qianfan import QianFanWorker

View File

@ -1,121 +0,0 @@
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
import requests
from typing import List, Literal
# Maps the user-facing model version name to the path segment of the
# Wenxin chat endpoint (interpolated into ErnieWorker.BASE_URL).
MODEL_VERSIONS = {
    "ernie-bot": "completions",
    "ernie-bot-turbo": "eb-instant"
}
class ErnieWorker(ApiModelWorker):
    """
    Model worker for Baidu's Ernie (Wenxin) online chat API.
    """
    # Endpoint template; the path segment varies with the model version.
    BASE_URL = ('https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'
                '/{model_version}?access_token={access_token}')
    SUPPORT_MODELS = list(MODEL_VERSIONS.keys())

    def __init__(
        self,
        *,
        version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
        model_names: List[str] = ["ernie-api"],
        controller_addr: str,
        worker_addr: str,
        **kwargs,
    ):
        kwargs.update(model_names=model_names,
                      controller_addr=controller_addr,
                      worker_addr=worker_addr)
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)

        # TODO: confirm whether this conversation template needs adjusting
        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )

        config = self.get_config()
        self.version = version
        self.api_key = config.get("api_key")
        self.secret_key = config.get("secret_key")
        self.access_token = self.get_access_token()

    def get_access_token(self):
        """
        Exchange the configured API key / secret key for an access token.
        """
        token_url = (f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials"
                     f"&client_id={self.api_key}"
                     f"&client_secret={self.secret_key}")
        request_headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        resp = requests.request("POST", token_url,
                                headers=request_headers,
                                data=json.dumps(""))
        return resp.json().get("access_token")

    def generate_stream_gate(self, params):
        """
        Stream the model's reply, yielding fastchat-style NUL-terminated JSON chunks.
        """
        url = self.BASE_URL.format(
            model_version=MODEL_VERSIONS[self.version],
            access_token=self.access_token,
        )
        body = json.dumps({
            "messages": self.prompt_to_messages(params["prompt"]),
            "stream": True
        })
        response = requests.request("POST", url,
                                    headers={'Content-Type': 'application/json'},
                                    data=body, stream=True)
        accumulated = ""
        for raw_line in response.iter_lines():
            decoded = raw_line.decode("utf-8")
            # The API streams SSE-style "data: ..." lines; skip everything else.
            if not decoded.startswith("data: "):
                continue
            chunk = json.loads(decoded[6:])
            if "result" in chunk.keys():
                accumulated += chunk["result"]
                out = {"error_code": 0, "text": accumulated}
            else:
                # Error payloads carry error_code / error_msg instead of result.
                out = {"error_code": chunk["error_code"], "text": chunk["error_msg"]}
            yield json.dumps(out, ensure_ascii=False).encode() + b"\0"

    def get_embeddings(self, params):
        # TODO: support embeddings
        print("embedding")
        print(params)
if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    # Fix: the class is named ErnieWorker; "EnrieWorker" was a typo that
    # raised NameError as soon as this script was run.
    worker = ErnieWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:20003",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=20003)

View File

@ -0,0 +1,142 @@
import json
import sys
from typing import Dict, Iterator, List, Literal

import httpx
from cachetools import cached, TTLCache
from fastchat import conversation as conv

from configs.model_config import TEMPERATURE
from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config
# TODO: support all qianfan models
# Maps the user-facing model version name to the path segment of the
# Qianfan chat endpoint (interpolated into BASE_URL in request_qianfan_api).
MODEL_VERSIONS = {
    "ernie-bot": "completions",
    "ernie-bot-turbo": "eb-instant",
}
@cached(TTLCache(1, 1800))  # tested: a cached token stays valid; refresh every 30 minutes
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
    """
    Exchange the AK/SK pair for a Baidu access token.

    Returns the token string, or None when the request fails (the failure
    is printed, not raised, so callers must check for a falsy result).
    """
    token_endpoint = "https://aip.baidubce.com/oauth/2.0/token"
    query = {
        "grant_type": "client_credentials",
        "client_id": api_key,
        "client_secret": secret_key,
    }
    try:
        resp = httpx.get(token_endpoint, params=query)
        return resp.json().get("access_token")
    except Exception as e:
        print(f"failed to get token from baidu: {e}")
def request_qianfan_api(
    messages: List[Dict[str, str]],
    temperature: float = TEMPERATURE,
    model_name: str = "qianfan-api",
    version: str = None,
) -> Iterator[Dict]:
    """
    Stream chat completions from the Baidu Qianfan API.

    Args:
        messages: OpenAI-style message dicts ({"role": ..., "content": ...}).
        temperature: sampling temperature; None falls back to the configured default.
        model_name: key used to look up api_key/secret_key/version in the worker config.
        version: one of MODEL_VERSIONS; defaults to the config's "version" entry.

    Yields:
        One parsed JSON dict per streamed chunk.

    Raises:
        RuntimeError: when the version is unsupported or no access token
            could be obtained.
    """
    BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
               '/{model_version}?access_token={access_token}'
    config = get_model_worker_config(model_name)
    version = version or config.get("version")
    # Fix: an unknown/missing version previously surfaced as a bare KeyError.
    if version not in MODEL_VERSIONS:
        raise RuntimeError(f"unsupported qianfan model version: {version}. "
                           f"supported versions: {list(MODEL_VERSIONS)}")
    # Fix: workers pass params.get("temperature"), which may be None and would
    # otherwise be forwarded verbatim to the API.
    if temperature is None:
        temperature = TEMPERATURE
    access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
    if not access_token:
        raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
    url = BASE_URL.format(
        model_version=MODEL_VERSIONS[version],
        access_token=access_token,
    )
    payload = {
        "messages": messages,
        "temperature": temperature,
        "stream": True
    }
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
    }
    with httpx.stream("POST", url, headers=headers, json=payload) as response:
        for line in response.iter_lines():
            if not line.strip():
                continue
            # Strip the SSE "data: " framing before parsing.
            if line.startswith("data: "):
                line = line[6:]
            yield json.loads(line)
class QianFanWorker(ApiModelWorker):
    """
    Model worker for the Baidu Qianfan (百度千帆) online API.
    """

    def __init__(
        self,
        *,
        version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
        model_names: List[str] = None,
        controller_addr: str,
        worker_addr: str,
        **kwargs,
    ):
        # Fix: the default model name must match the "qianfan-api" config entry
        # (renamed from "ernie-api" in this commit); otherwise get_config() /
        # request_qianfan_api look up a key that no longer exists.
        # Also avoids the mutable-default-argument pitfall.
        model_names = model_names or ["qianfan-api"]
        kwargs.update(model_names=model_names,
                      controller_addr=controller_addr,
                      worker_addr=worker_addr)
        # NOTE(review): 16384 looks generous for ernie-bot models — confirm
        # against the Qianfan API limits.
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)

        # TODO: confirm whether this conversation template needs adjusting
        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )

        config = self.get_config()
        self.version = version
        self.api_key = config.get("api_key")
        self.secret_key = config.get("secret_key")

    def generate_stream_gate(self, params):
        """
        Stream the model's reply, yielding fastchat-style NUL-terminated JSON chunks.
        """
        messages = self.prompt_to_messages(params["prompt"])
        text = ""
        # Fix: fall back to the configured default temperature when the request
        # carries none, instead of forwarding temperature=None downstream.
        for resp in request_qianfan_api(messages,
                                        temperature=params.get("temperature", TEMPERATURE),
                                        model_name=self.model_names[0]):
            if "result" in resp.keys():
                text += resp["result"]
                yield json.dumps({
                        "error_code": 0,
                        "text": text
                    },
                    ensure_ascii=False
                ).encode() + b"\0"
            else:
                # Error payloads carry error_code / error_msg instead of result.
                yield json.dumps({
                        "error_code": resp["error_code"],
                        "text": resp["error_msg"]
                    },
                    ensure_ascii=False
                ).encode() + b"\0"

    def get_embeddings(self, params):
        # TODO: support embeddings via the Qianfan embedding endpoint
        print("embedding")
        print(params)
if __name__ == "__main__":
    import uvicorn
    from fastchat.serve.model_worker import app
    from server.utils import MakeFastAPIOffline

    # Stand-alone entry point: run this worker on port 20006 and register it
    # with the local fastchat controller.
    worker = QianFanWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:20006",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=20006)

View File

@ -0,0 +1,20 @@
import sys
from pathlib import Path

# Make the repository root importable so the test can be run directly
# (tests/ lives three levels below the project root).
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))

from server.model_workers.qianfan import request_qianfan_api, MODEL_VERSIONS
from pprint import pprint
import pytest
@pytest.mark.parametrize("version", MODEL_VERSIONS.keys())
def test_qianfan(version):
    """Smoke-test streaming chat against every supported qianfan model version."""
    messages = [{"role": "user", "content": "你好"}]
    print("\n" + version + "\n")

    chunks = 0
    for x in request_qianfan_api(messages, version=version):
        pprint(x)
        assert isinstance(x, dict)
        assert "error_code" not in x
        chunks += 1
    # Fix: the original counter was incremented but never checked, so an
    # empty response stream made the test pass silently.
    assert chunks > 0