From 0e20552083561314ef47b63cddb7641b83b31c3b Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Fri, 29 Sep 2023 13:16:14 +0800
Subject: [PATCH] 1. Add baichuan-api support; 2. Add copy_config_example.py,
 a script that batch-copies the .example files under configs to .py files;
 3. Update model_config.py.example
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/model_config.py.example  |   8 ++
 copy_config_example.py           |  13 +++
 server/model_workers/baichuan.py | 163 +++++++++++++++++++++++++++++++
 3 files changed, 184 insertions(+)
 create mode 100644 copy_config_example.py
 create mode 100644 server/model_workers/baichuan.py

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index f5e8465..464e01b 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -174,6 +174,14 @@ ONLINE_LLM_MODEL = {
         "api_key": "",  # 请在阿里云控制台模型服务灵积API-KEY管理页面创建
         "provider": "QwenWorker",
     },
+
+    # Baichuan API; see https://www.baichuan-ai.com/home#api-enter for how to apply for access
+    "baichuan-api": {
+        "version": "Baichuan2-53B",  # currently "Baichuan2-53B" is supported; see the official docs.
+        "api_key": "",
+        "secret_key": "",
+        "provider": "BaiChuanWorker",
+    },
 }
 
 
diff --git a/copy_config_example.py b/copy_config_example.py
new file mode 100644
index 0000000..0bba881
--- /dev/null
+++ b/copy_config_example.py
@@ -0,0 +1,13 @@
+# Batch-copy the .example files under configs and rename the copies to .py files
+import os
+import shutil
+
+files = os.listdir("configs")
+
+src_files = [os.path.join("configs", file) for file in files if ".example" in file]
+
+for src_file in src_files:
+    tar_file = src_file.replace(".example", "")
+    shutil.copy(src_file, tar_file)
diff --git a/server/model_workers/baichuan.py b/server/model_workers/baichuan.py
new file mode 100644
index 0000000..7d68586
--- /dev/null
+++ b/server/model_workers/baichuan.py
@@ -0,0 +1,163 @@
+# import os
+# import sys
+# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+import requests
+import json
+import sys
+import time
+import hashlib
+from typing import List, Literal
+
+from fastchat import conversation as conv
+
+from server.model_workers.base import ApiModelWorker
+from configs import TEMPERATURE
+
+
+def calculate_md5(input_string):
+    # Hex MD5 digest of a UTF-8 string; used to sign Baichuan API requests.
+    md5 = hashlib.md5()
+    md5.update(input_string.encode('utf-8'))
+    encrypted = md5.hexdigest()
+    return encrypted
+
+
+def do_request():
+    # Standalone sanity check against the Baichuan streaming endpoint.
+    url = "https://api.baichuan-ai.com/v1/stream/chat"
+    api_key = ""
+    secret_key = ""
+
+    data = {
+        "model": "Baichuan2-53B",
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is the highest mountain in the world?"
+            }
+        ],
+        "parameters": {
+            "temperature": 0.1,
+            "top_k": 10
+        }
+    }
+
+    json_data = json.dumps(data)
+    time_stamp = int(time.time())
+    signature = calculate_md5(secret_key + json_data + str(time_stamp))
+
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": "Bearer " + api_key,
+        "X-BC-Request-Id": "your requestId",
+        "X-BC-Timestamp": str(time_stamp),
+        "X-BC-Signature": signature,
+        "X-BC-Sign-Algo": "MD5",
+    }
+
+    response = requests.post(url, data=json_data, headers=headers)
+
+    if response.status_code == 200:
+        print("Request succeeded!")
+        print("Response headers:", response.headers)
+        print("Response body:", response.text)
+    else:
+        print("Request failed, status code:", response.status_code)
+
+
+class BaiChuanWorker(ApiModelWorker):
+    BASE_URL = "https://api.baichuan-ai.com/v1/chat"
+    SUPPORT_MODELS = ["Baichuan2-53B"]
+
+    def __init__(
+        self,
+        *,
+        controller_addr: str,
+        worker_addr: str,
+        model_names: List[str] = ["baichuan-api"],
+        version: Literal["Baichuan2-53B"] = "Baichuan2-53B",
+        **kwargs,
+    ):
+        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+        kwargs.setdefault("context_len", 32768)
+        super().__init__(**kwargs)
+
+        # TODO: confirm whether this conversation template needs to be adjusted
+        self.conv = conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
+        config = self.get_config()
+        self.version = config.get("version", version)
+        self.api_key = config.get("api_key")
+        self.secret_key = config.get("secret_key")
+
+    def generate_stream_gate(self, params):
+        data = {
+            "model": self.version,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": params["prompt"]
+                }
+            ],
+            "parameters": {
+                "temperature": params.get("temperature", TEMPERATURE),
+                "top_k": params.get("top_k", 1)
+            }
+        }
+
+        # Sign the request: MD5 over secret_key + request body + timestamp.
+        json_data = json.dumps(data)
+        time_stamp = int(time.time())
+        signature = calculate_md5(self.secret_key + json_data + str(time_stamp))
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer " + self.api_key,
+            "X-BC-Request-Id": "your requestId",
+            "X-BC-Timestamp": str(time_stamp),
+            "X-BC-Signature": signature,
+            "X-BC-Sign-Algo": "MD5",
+        }
+
+        response = requests.post(self.BASE_URL, data=json_data, headers=headers)
+
+        if response.status_code == 200:
+            resp = response.json()  # parse the JSON body instead of calling eval() on untrusted text
+            yield json.dumps(
+                {
+                    "error_code": resp["code"],
+                    "text": resp["data"]["messages"][-1]["content"]
+                },
+                ensure_ascii=False
+            ).encode() + b"\0"
+        else:
+            # No parsed body on failure; report the HTTP status and raw response text.
+            yield json.dumps(
+                {
+                    "error_code": response.status_code,
+                    "text": response.text
+                },
+                ensure_ascii=False
+            ).encode() + b"\0"
+
+    def get_embeddings(self, params):
+        # TODO: support embeddings
+        print("embedding")
+        print(params)
+
+
+if __name__ == "__main__":
+    # import uvicorn
+    # from server.utils import MakeFastAPIOffline
+    # from fastchat.serve.model_worker import app
+
+    # worker = BaiChuanWorker(
+    #     controller_addr="http://127.0.0.1:20001",
+    #     worker_addr="http://127.0.0.1:21001",
+    # )
+    # sys.modules["fastchat.serve.model_worker"].worker = worker
+    # MakeFastAPIOffline(app)
+    # uvicorn.run(app, port=21001)
+    # do_request()
+    pass  # placeholder so the block stays valid Python while the launch code above is commented out
\ No newline at end of file
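
Usage note for reviewers: generate_stream_gate yields fastchat-style chunks, i.e. UTF-8 JSON bytes terminated by a null byte. The sketch below shows one way a caller could decode that framing; it assumes only the chunk format visible in the patch, and the helper name decode_worker_stream is illustrative, not part of this change.

import json
from typing import Iterable, Iterator

def decode_worker_stream(chunks: Iterable[bytes]) -> Iterator[dict]:
    # Reassemble the b"\0"-delimited JSON messages emitted by generate_stream_gate.
    buffer = b""
    for chunk in chunks:
        buffer += chunk
        while b"\0" in buffer:
            piece, buffer = buffer.split(b"\0", 1)
            if piece:
                yield json.loads(piece.decode("utf-8"))

# Example with a hypothetical worker instance:
# for msg in decode_worker_stream(worker.generate_stream_gate({"prompt": "Hello"})):
#     print(msg["error_code"], msg["text"])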