diff --git a/server/model_workers/baichuan.py b/server/model_workers/baichuan.py index 879a51e..1c6a6f1 100644 --- a/server/model_workers/baichuan.py +++ b/server/model_workers/baichuan.py @@ -1,15 +1,15 @@ # import os # import sys # sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) -import requests import json import time import hashlib from server.model_workers.base import ApiModelWorker +from server.utils import get_model_worker_config, get_httpx_client from fastchat import conversation as conv import sys import json -from typing import List, Literal +from typing import List, Literal, Dict from configs import TEMPERATURE @@ -20,29 +20,29 @@ def calculate_md5(input_string): return encrypted -def do_request(): - url = "https://api.baichuan-ai.com/v1/stream/chat" - api_key = "" - secret_key = "" +def request_baichuan_api( + messages: List[Dict[str, str]], + api_key: str = None, + secret_key: str = None, + version: str = "Baichuan2-53B", + temperature: float = TEMPERATURE, + model_name: str = "baichuan-api", +): + config = get_model_worker_config(model_name) + api_key = api_key or config.get("api_key") + secret_key = secret_key or config.get("secret_key") + version = version or config.get("version") + url = "https://api.baichuan-ai.com/v1/stream/chat" data = { - "model": "Baichuan2-53B", - "messages": [ - { - "role": "user", - "content": "世界第一高峰是" - } - ], - "parameters": { - "temperature": 0.1, - "top_k": 10 - } + "model": version, + "messages": messages, + "parameters": {"temperature": temperature} } json_data = json.dumps(data) time_stamp = int(time.time()) signature = calculate_md5(secret_key + json_data + str(time_stamp)) - headers = { "Content-Type": "application/json", "Authorization": "Bearer " + api_key, @@ -52,18 +52,17 @@ def do_request(): "X-BC-Sign-Algo": "MD5", } - response = requests.post(url, data=json_data, headers=headers) - - if response.status_code == 200: - print("请求成功!") - print("响应header:", response.headers) - print("响应body:", response.text) - else: - print("请求失败,状态码:", response.status_code) + with get_httpx_client() as client: + with client.stream("POST", url, headers=headers, json=data) as response: + for line in response.iter_lines(): + if not line.strip(): + continue + resp = json.loads(line) + yield resp class BaiChuanWorker(ApiModelWorker): - BASE_URL = "https://api.baichuan-ai.com/v1/chat" + BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat" SUPPORT_MODELS = ["Baichuan2-53B"] def __init__( @@ -95,54 +94,34 @@ class BaiChuanWorker(ApiModelWorker): self.secret_key = config.get("secret_key") def generate_stream_gate(self, params): - data = { - "model": self.version, - "messages": [ - { - "role": "user", - "content": params["prompt"] - } - ], - "parameters": { - "temperature": params.get("temperature",TEMPERATURE), - "top_k": params.get("top_k",1) - } - } + super().generate_stream_gate(params) - json_data = json.dumps(data) - time_stamp = int(time.time()) - signature = calculate_md5(self.secret_key + json_data + str(time_stamp)) - headers = { - "Content-Type": "application/json", - "Authorization": "Bearer " + self.api_key, - "X-BC-Request-Id": "your requestId", - "X-BC-Timestamp": str(time_stamp), - "X-BC-Signature": signature, - "X-BC-Sign-Algo": "MD5", - } + messages = self.prompt_to_messages(params["prompt"]) - response = requests.post(self.BASE_URL, data=json_data, headers=headers) + text = "" + for resp in request_baichuan_api(messages=messages, + api_key=self.api_key, + secret_key=self.secret_key, + version=self.version, + temperature=params.get("temperature")): + if resp["code"] == 0: + text += resp["data"]["messages"][-1]["content"] + yield json.dumps( + { + "error_code": resp["code"], + "text": text + }, + ensure_ascii=False + ).encode() + b"\0" + else: + yield json.dumps( + { + "error_code": resp["code"], + "text": resp["msg"] + }, + ensure_ascii=False + ).encode() + b"\0" - if response.status_code == 200: - resp = eval(response.text) - yield json.dumps( - { - "error_code": resp["code"], - "text": resp["data"]["messages"][-1]["content"] - }, - ensure_ascii=False - ).encode() + b"\0" - else: - yield json.dumps( - { - "error_code": resp["code"], - "text": resp["msg"] - }, - ensure_ascii=False - ).encode() + b"\0" - - - def get_embeddings(self, params): # TODO: 支持embeddings print("embedding") diff --git a/tests/online_api/test_baichuan.py b/tests/online_api/test_baichuan.py new file mode 100644 index 0000000..536466e --- /dev/null +++ b/tests/online_api/test_baichuan.py @@ -0,0 +1,16 @@ +import sys +from pathlib import Path +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) + +from server.model_workers.baichuan import request_baichuan_api +from pprint import pprint + + +def test_qwen(): + messages = [{"role": "user", "content": "hello"}] + + for x in request_baichuan_api(messages): + print(type(x)) + pprint(x) + assert x["code"] == 0 \ No newline at end of file