update baichuan-api: 修正messages参数;支持流式;添加测试用例

This commit is contained in:
liunux4odoo 2023-10-20 19:09:05 +08:00
parent e920cd0064
commit 1d9d9df9e9
2 changed files with 67 additions and 72 deletions

View File

@ -1,15 +1,15 @@
# import os # import os
# import sys # import sys
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) # sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import requests
import json import json
import time import time
import hashlib import hashlib
from server.model_workers.base import ApiModelWorker from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config, get_httpx_client
from fastchat import conversation as conv from fastchat import conversation as conv
import sys import sys
import json import json
from typing import List, Literal from typing import List, Literal, Dict
from configs import TEMPERATURE from configs import TEMPERATURE
@ -20,29 +20,29 @@ def calculate_md5(input_string):
return encrypted return encrypted
def do_request(): def request_baichuan_api(
url = "https://api.baichuan-ai.com/v1/stream/chat" messages: List[Dict[str, str]],
api_key = "" api_key: str = None,
secret_key = "" secret_key: str = None,
version: str = "Baichuan2-53B",
temperature: float = TEMPERATURE,
model_name: str = "baichuan-api",
):
config = get_model_worker_config(model_name)
api_key = api_key or config.get("api_key")
secret_key = secret_key or config.get("secret_key")
version = version or config.get("version")
url = "https://api.baichuan-ai.com/v1/stream/chat"
data = { data = {
"model": "Baichuan2-53B", "model": version,
"messages": [ "messages": messages,
{ "parameters": {"temperature": temperature}
"role": "user",
"content": "世界第一高峰是"
}
],
"parameters": {
"temperature": 0.1,
"top_k": 10
}
} }
json_data = json.dumps(data) json_data = json.dumps(data)
time_stamp = int(time.time()) time_stamp = int(time.time())
signature = calculate_md5(secret_key + json_data + str(time_stamp)) signature = calculate_md5(secret_key + json_data + str(time_stamp))
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": "Bearer " + api_key, "Authorization": "Bearer " + api_key,
@ -52,18 +52,17 @@ def do_request():
"X-BC-Sign-Algo": "MD5", "X-BC-Sign-Algo": "MD5",
} }
response = requests.post(url, data=json_data, headers=headers) with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
if response.status_code == 200: for line in response.iter_lines():
print("请求成功!") if not line.strip():
print("响应header:", response.headers) continue
print("响应body:", response.text) resp = json.loads(line)
else: yield resp
print("请求失败,状态码:", response.status_code)
class BaiChuanWorker(ApiModelWorker): class BaiChuanWorker(ApiModelWorker):
BASE_URL = "https://api.baichuan-ai.com/v1/chat" BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat"
SUPPORT_MODELS = ["Baichuan2-53B"] SUPPORT_MODELS = ["Baichuan2-53B"]
def __init__( def __init__(
@ -95,54 +94,34 @@ class BaiChuanWorker(ApiModelWorker):
self.secret_key = config.get("secret_key") self.secret_key = config.get("secret_key")
def generate_stream_gate(self, params): def generate_stream_gate(self, params):
data = { super().generate_stream_gate(params)
"model": self.version,
"messages": [
{
"role": "user",
"content": params["prompt"]
}
],
"parameters": {
"temperature": params.get("temperature",TEMPERATURE),
"top_k": params.get("top_k",1)
}
}
json_data = json.dumps(data) messages = self.prompt_to_messages(params["prompt"])
time_stamp = int(time.time())
signature = calculate_md5(self.secret_key + json_data + str(time_stamp))
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(time_stamp),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",
}
response = requests.post(self.BASE_URL, data=json_data, headers=headers) text = ""
for resp in request_baichuan_api(messages=messages,
api_key=self.api_key,
secret_key=self.secret_key,
version=self.version,
temperature=params.get("temperature")):
if resp["code"] == 0:
text += resp["data"]["messages"][-1]["content"]
yield json.dumps(
{
"error_code": resp["code"],
"text": text
},
ensure_ascii=False
).encode() + b"\0"
else:
yield json.dumps(
{
"error_code": resp["code"],
"text": resp["msg"]
},
ensure_ascii=False
).encode() + b"\0"
if response.status_code == 200:
resp = eval(response.text)
yield json.dumps(
{
"error_code": resp["code"],
"text": resp["data"]["messages"][-1]["content"]
},
ensure_ascii=False
).encode() + b"\0"
else:
yield json.dumps(
{
"error_code": resp["code"],
"text": resp["msg"]
},
ensure_ascii=False
).encode() + b"\0"
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings
print("embedding") print("embedding")

View File

@ -0,0 +1,16 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.baichuan import request_baichuan_api
from pprint import pprint
def test_qwen():
messages = [{"role": "user", "content": "hello"}]
for x in request_baichuan_api(messages):
print(type(x))
pprint(x)
assert x["code"] == 0