From 4cf2e5ea5ebcbc02b288fe6d63a1381375f3c8e7 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Thu, 14 Sep 2023 23:37:34 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=96=87=E5=BF=83=E4=B8=80?= =?UTF-8?q?=E8=A8=80=EF=BC=8C=E6=B7=BB=E5=8A=A0=E6=B5=8B=E8=AF=95=E7=94=A8?= =?UTF-8?q?=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/model_config.py.example | 6 +- configs/server_config.py.example | 4 + server/model_workers/__init__.py | 2 +- server/model_workers/ernie.py | 121 -------------------------- server/model_workers/qianfan.py | 142 +++++++++++++++++++++++++++++++ tests/online_api/test_qianfan.py | 20 +++++ 6 files changed, 170 insertions(+), 125 deletions(-) delete mode 100644 server/model_workers/ernie.py create mode 100644 server/model_workers/qianfan.py create mode 100644 tests/online_api/test_qianfan.py diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 2057e10..b2256d3 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -103,9 +103,9 @@ llm_model_dict = { "is_v2": False, "provider": "XingHuoWorker", }, - # Ernie Bot Turbo API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf - "ernie-api": { - "version": "ernie-bot-turbo", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo" + # 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf + "qianfan-api": { + "version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo" "api_base_url": "http://127.0.0.1:8888/v1", "api_key": "", "secret_key": "", diff --git a/configs/server_config.py.example b/configs/server_config.py.example index 3107c1d..60ed32e 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -59,6 +59,7 @@ FSCHAT_MODEL_WORKERS = { # "limit_worker_concurrency": 5, # "stream_interval": 2, # "no_register": False, + # "embed_in_truncate": False, }, "baichuan-7b": { # 使用default中的IP和端口 "device": "cpu", @@ -72,6 +73,9 @@ FSCHAT_MODEL_WORKERS = { "xinghuo-api": { # 请为每个在线API设置不同的端口 "port": 20005, }, + "qianfan-api": { + "port": 20006, + }, } # fastchat multi model worker server diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py index cb23a4a..a3a162f 100644 --- a/server/model_workers/__init__.py +++ b/server/model_workers/__init__.py @@ -1,4 +1,4 @@ from .zhipu import ChatGLMWorker from .minimax import MiniMaxWorker from .xinghuo import XingHuoWorker -from .ernie import ErnieWorker +from .qianfan import QianFanWorker diff --git a/server/model_workers/ernie.py b/server/model_workers/ernie.py deleted file mode 100644 index 3c9920b..0000000 --- a/server/model_workers/ernie.py +++ /dev/null @@ -1,121 +0,0 @@ -from server.model_workers.base import ApiModelWorker -from fastchat import conversation as conv -import sys -import json -import requests -from typing import List, Literal - -MODEL_VERSIONS = { - "ernie-bot": "completions", - "ernie-bot-turbo": "eb-instant" -} - - -class ErnieWorker(ApiModelWorker): - """ - 百度 Ernie - """ - BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\ - '/{model_version}?access_token={access_token}' - SUPPORT_MODELS = list(MODEL_VERSIONS.keys()) - - def __init__( - self, - *, - version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot", - model_names: List[str] = ["ernie-api"], - controller_addr: str, - worker_addr: str, - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 16384) - super().__init__(**kwargs) - - # TODO: 确认模板是否需要修改 - self.conv = conv.Conversation( - name=self.model_names[0], - system_message="", - messages=[], - roles=["user", "assistant"], - sep="\n### ", - stop_str="###", - ) - - config = self.get_config() - self.version = version - self.api_key = config.get("api_key") - self.secret_key = config.get("secret_key") - self.access_token = self.get_access_token() - - def get_access_token(self): - """ - 使用 API Key,Secret Key 获取access_token,替换下列示例中的应用API Key、应用Secret Key - """ - - url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials"\ - f"&client_id={self.api_key}"\ - f"&client_secret={self.secret_key}" - - payload = json.dumps("") - headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json' - } - - response = requests.request("POST", url, headers=headers, data=payload) - return response.json().get("access_token") - - def generate_stream_gate(self, params): - url = self.BASE_URL.format( - model_version=MODEL_VERSIONS[self.version], - access_token=self.access_token - ) - payload = json.dumps({ - "messages": self.prompt_to_messages(params["prompt"]), - "stream": True - }) - headers = { - 'Content-Type': 'application/json' - } - - response = requests.request("POST", url, headers=headers, data=payload, stream=True) - - text="" - for line in response.iter_lines(): - if line.decode("utf-8").startswith("data: "): # 真是优秀的返回 - resp = json.loads(line.decode("utf-8")[6:]) - if "result" in resp.keys(): - text += resp["result"] - yield json.dumps({ - "error_code": 0, - "text": text - }, - ensure_ascii=False - ).encode() + b"\0" - else: - yield json.dumps({ - "error_code": resp["error_code"], - "text": resp["error_msg"] - }, - ensure_ascii=False - ).encode() + b"\0" - - def get_embeddings(self, params): - # TODO: 支持embeddings - print("embedding") - print(params) - - -if __name__ == "__main__": - import uvicorn - from server.utils import MakeFastAPIOffline - from fastchat.serve.model_worker import app - - worker = EnrieWorker( - controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:20003", - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - MakeFastAPIOffline(app) - uvicorn.run(app, port=20003) \ No newline at end of file diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py new file mode 100644 index 0000000..cbdd996 --- /dev/null +++ b/server/model_workers/qianfan.py @@ -0,0 +1,142 @@ +from server.model_workers.base import ApiModelWorker +from configs.model_config import TEMPERATURE +from fastchat import conversation as conv +import sys +import json +import httpx +from cachetools import cached, TTLCache +from server.utils import get_model_worker_config +from typing import List, Literal, Dict + + +# TODO: support all qianfan models +MODEL_VERSIONS = { + "ernie-bot": "completions", + "ernie-bot-turbo": "eb-instant", +} + + +@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次 +def get_baidu_access_token(api_key: str, secret_key: str) -> str: + """ + 使用 AK,SK 生成鉴权签名(Access Token) + :return: access_token,或是None(如果错误) + """ + url = "https://aip.baidubce.com/oauth/2.0/token" + params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} + try: + return httpx.get(url, params=params).json().get("access_token") + except Exception as e: + print(f"failed to get token from baidu: {e}") + + +def request_qianfan_api( + messages: List[Dict[str, str]], + temperature: float = TEMPERATURE, + model_name: str = "qianfan-api", + version: str = None, +) -> Dict: + BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\ + '/{model_version}?access_token={access_token}' + config = get_model_worker_config(model_name) + version = version or config.get("version") + access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key")) + if not access_token: + raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?") + + url = BASE_URL.format( + model_version=MODEL_VERSIONS[version], + access_token=access_token, + ) + payload = { + "messages": messages, + "temperature": temperature, + "stream": True + } + headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + } + + with httpx.stream("POST", url, headers=headers, json=payload) as response: + for line in response.iter_lines(): + if not line.strip(): + continue + if line.startswith("data: "): + line = line[6:] + resp = json.loads(line) + yield resp + + +class QianFanWorker(ApiModelWorker): + """ + 百度千帆 + """ + def __init__( + self, + *, + version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot", + model_names: List[str] = ["ernie-api"], + controller_addr: str, + worker_addr: str, + **kwargs, + ): + kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) + kwargs.setdefault("context_len", 16384) + super().__init__(**kwargs) + + # TODO: 确认模板是否需要修改 + self.conv = conv.Conversation( + name=self.model_names[0], + system_message="", + messages=[], + roles=["user", "assistant"], + sep="\n### ", + stop_str="###", + ) + + config = self.get_config() + self.version = version + self.api_key = config.get("api_key") + self.secret_key = config.get("secret_key") + + def generate_stream_gate(self, params): + messages = self.prompt_to_messages(params["prompt"]) + text="" + for resp in request_qianfan_api(messages, + temperature=params.get("temperature"), + model_name=self.model_names[0]): + if "result" in resp.keys(): + text += resp["result"] + yield json.dumps({ + "error_code": 0, + "text": text + }, + ensure_ascii=False + ).encode() + b"\0" + else: + yield json.dumps({ + "error_code": resp["error_code"], + "text": resp["error_msg"] + }, + ensure_ascii=False + ).encode() + b"\0" + + def get_embeddings(self, params): + # TODO: 支持embeddings + print("embedding") + print(params) + + +if __name__ == "__main__": + import uvicorn + from server.utils import MakeFastAPIOffline + from fastchat.serve.model_worker import app + + worker = QianFanWorker( + controller_addr="http://127.0.0.1:20001", + worker_addr="http://127.0.0.1:20006", + ) + sys.modules["fastchat.serve.model_worker"].worker = worker + MakeFastAPIOffline(app) + uvicorn.run(app, port=20006) \ No newline at end of file diff --git a/tests/online_api/test_qianfan.py b/tests/online_api/test_qianfan.py new file mode 100644 index 0000000..0e8a948 --- /dev/null +++ b/tests/online_api/test_qianfan.py @@ -0,0 +1,20 @@ +import sys +from pathlib import Path +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) + +from server.model_workers.qianfan import request_qianfan_api, MODEL_VERSIONS +from pprint import pprint +import pytest + + +@pytest.mark.parametrize("version", MODEL_VERSIONS.keys()) +def test_qianfan(version): + messages = [{"role": "user", "content": "你好"}] + print("\n" + version + "\n") + i = 1 + for x in request_qianfan_api(messages, version=version): + pprint(x) + assert isinstance(x, dict) + assert "error_code" not in x + i += 1