From 512638a3b18e2f258f00b7b974a8519ef2b11d55 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Wed, 13 Sep 2023 13:51:05 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=AE=AF=E9=A3=9E=E6=98=9F?= =?UTF-8?q?=E7=81=AB=E5=A4=A7=E6=A8=A1=E5=9E=8B=E5=9C=A8=E7=BA=BFAPI=20(#1?= =?UTF-8?q?460)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 支持讯飞星火大模型在线API --- configs/model_config.py.example | 9 +- configs/server_config.py.example | 3 + server/model_workers/SparkApi.py | 137 +++++++++++++++++++++++++++++++ server/model_workers/__init__.py | 1 + server/model_workers/base.py | 20 +++++ server/model_workers/minimax.py | 16 +--- server/model_workers/xinghuo.py | 103 +++++++++++++++++++++++ 7 files changed, 276 insertions(+), 13 deletions(-) create mode 100644 server/model_workers/SparkApi.py create mode 100644 server/model_workers/xinghuo.py diff --git a/configs/model_config.py.example b/configs/model_config.py.example index a552441..977c25f 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -95,7 +95,14 @@ llm_model_dict = { "is_pro": False, "provider": "MiniMaxWorker", }, - +"xinghuo-api": { + "api_base_url": "http://127.0.0.1:8888/v1", + "APPID": "", + "APISecret": "", + "api_key": "", + "is_v2": False, + "provider": "XingHuoWorker", + } } # LLM 名称 diff --git a/configs/server_config.py.example b/configs/server_config.py.example index 0fdd492..bc21d9a 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -70,6 +70,9 @@ FSCHAT_MODEL_WORKERS = { "minimax-api": { # 请为每个在线API设置不同的端口 "port": 20004, }, + "xinghuo-api": { # 请为每个在线API设置不同的端口 + "port": 20005, + }, } # fastchat multi model worker server diff --git a/server/model_workers/SparkApi.py b/server/model_workers/SparkApi.py new file mode 100644 index 0000000..3b4f096 --- /dev/null +++ b/server/model_workers/SparkApi.py @@ -0,0 +1,137 @@ +import _thread as thread +import base64 +import datetime +import hashlib +import hmac +import json +from urllib.parse import urlparse +import ssl +from datetime import datetime +from time import mktime +from urllib.parse import urlencode +from wsgiref.handlers import format_date_time + +import websocket # 使用websocket_client +answer = "" + +class Ws_Param(object): + # 初始化 + def __init__(self, APPID, APIKey, APISecret, Spark_url): + self.APPID = APPID + self.APIKey = APIKey + self.APISecret = APISecret + self.host = urlparse(Spark_url).netloc + self.path = urlparse(Spark_url).path + self.Spark_url = Spark_url + + # 生成url + def create_url(self): + # 生成RFC1123格式的时间戳 + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + + # 拼接字符串 + signature_origin = "host: " + self.host + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + self.path + " HTTP/1.1" + + # 进行hmac-sha256进行加密 + signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + + authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' + + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + + # 将请求的鉴权参数组合为字典 + v = { + "authorization": authorization, + "date": date, + "host": self.host + } + # 拼接鉴权参数,生成url + url = self.Spark_url + '?' + urlencode(v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 + return url + + +# 收到websocket错误的处理 +def on_error(ws, error): + print("### error:", error) + + +# 收到websocket关闭的处理 +def on_close(ws,one,two): + print(" ") + + +# 收到websocket连接建立的处理 +def on_open(ws): + thread.start_new_thread(run, (ws,)) + + +def run(ws, *args): + data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature=ws.temperature)) + ws.send(data) + + +# 收到websocket消息的处理 +def on_message(ws, message): + # print(message) + data = json.loads(message) + code = data['header']['code'] + if code != 0: + print(f'请求错误: {code}, {data}') + ws.close() + else: + choices = data["payload"]["choices"] + status = choices["status"] + content = choices["text"][0]["content"] + print(content,end ="") + global answer + answer += content + # print(1) + if status == 2: + ws.close() + + +def gen_params(appid, domain,question, temperature): + """ + 通过appid和用户的提问来生成请参数 + """ + data = { + "header": { + "app_id": appid, + "uid": "1234" + }, + "parameter": { + "chat": { + "domain": domain, + "random_threshold": 0.5, + "max_tokens": 2048, + "auditing": "default", + "temperature": temperature, + } + }, + "payload": { + "message": { + "text": question + } + } + } + return data + + +def main(appid, api_key, api_secret, Spark_url,domain, question, temperature): + # print("星火:") + wsParam = Ws_Param(appid, api_key, api_secret, Spark_url) + websocket.enableTrace(False) + wsUrl = wsParam.create_url() + ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open) + ws.appid = appid + ws.question = question + ws.domain = domain + ws.temperature = temperature + ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py index 4ec62dd..f904ec8 100644 --- a/server/model_workers/__init__.py +++ b/server/model_workers/__init__.py @@ -1,2 +1,3 @@ from .zhipu import ChatGLMWorker from .minimax import MiniMaxWorker +from .xinghuo import XingHuoWorker diff --git a/server/model_workers/base.py b/server/model_workers/base.py index 090cf72..df5fbfc 100644 --- a/server/model_workers/base.py +++ b/server/model_workers/base.py @@ -74,3 +74,23 @@ class ApiModelWorker(BaseModelWorker): def get_config(self): from server.utils import get_model_worker_config return get_model_worker_config(self.model_names[0]) + + def prompt_to_messages(self, prompt: str) -> List[Dict]: + ''' + 将prompt字符串拆分成messages. + ''' + result = [] + user_role = self.conv.roles[0] + ai_role = self.conv.roles[1] + user_start = user_role + ":" + ai_start = ai_role + ":" + for msg in prompt.split(self.conv.sep)[1:-1]: + if msg.startswith(user_start): + if content := msg[len(user_start):].strip(): + result.append({"role": user_role, "content": content}) + elif msg.startswith(ai_start): + if content := msg[len(ai_start):].strip(): + result.append({"role": ai_role, "content": content}) + else: + raise RuntimeError(f"unknow role in msg: {msg}") + return result diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py index 895c85a..c772c0d 100644 --- a/server/model_workers/minimax.py +++ b/server/model_workers/minimax.py @@ -33,21 +33,13 @@ class MiniMaxWorker(ApiModelWorker): ) def prompt_to_messages(self, prompt: str) -> List[Dict]: - result = [] - user_start = self.conv.roles[0] + ":" - bot_start = self.conv.roles[1] + ":" - for msg in prompt.split(self.conv.sep)[1:-1]: - if msg.startswith(user_start): - result.append({"sender_type": "USER", "text": msg[len(user_start):].strip()}) - elif msg.startswith(bot_start): - result.append({"sender_type": "BOT", "text": msg[len(bot_start):].strip()}) - else: - raise RuntimeError(f"unknow role in msg: {msg}") - return result + result = super().prompt_to_messages(prompt) + messages = [{"sender_type": x["role"], "text": x["content"]} for x in result] + return messages def generate_stream_gate(self, params): # 按照官网推荐,直接调用abab 5.5模型 - # TODO: 支持历史消息,支持指定回复要求,支持指定用户名称、AI名称 + # TODO: 支持指定回复要求,支持指定用户名称、AI名称 super().generate_stream_gate(params) config = self.get_config() diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py new file mode 100644 index 0000000..20ffe74 --- /dev/null +++ b/server/model_workers/xinghuo.py @@ -0,0 +1,103 @@ +from server.model_workers.base import ApiModelWorker +from fastchat import conversation as conv +import sys +import json +import httpx +from pprint import pprint +import SparkApi +import websockets +from server.utils import iter_over_async, asyncio +from typing import List, Dict + + +async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature): + # print("星火:") + wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url) + wsUrl = wsParam.create_url() + data = SparkApi.gen_params(appid, domain, question, temperature) + async with websockets.connect(wsUrl) as ws: + await ws.send(json.dumps(data, ensure_ascii=False)) + finish = False + while not finish: + chunk = await ws.recv() + response = json.loads(chunk) + if response.get("header", {}).get("status") == 2: + finish = True + if text := response.get("payload", {}).get("choices", {}).get("text"): + yield text[0]["content"] + + +class XingHuoWorker(ApiModelWorker): + def __init__( + self, + *, + model_names: List[str] = ["xinghuo-api"], + controller_addr: str, + worker_addr: str, + **kwargs, + ): + kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) + kwargs.setdefault("context_len", 8192) + super().__init__(**kwargs) + + # TODO: 确认模板是否需要修改 + self.conv = conv.Conversation( + name=self.model_names[0], + system_message="", + messages=[], + roles=["user", "assistant"], + sep="\n### ", + stop_str="###", + ) + + def generate_stream_gate(self, params): + # TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接 + + super().generate_stream_gate(params) + config = self.get_config() + appid = config.get("APPID") + api_secret = config.get("APISecret") + api_key = config.get("api_key") + + if config.get("is_v2"): + domain = "generalv2" # v2.0版本 + Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址 + else: + domain = "general" # v1.5版本 + Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址 + + question = self.prompt_to_messages(params["prompt"]) + text = "" + + try: + loop = asyncio.get_event_loop() + except: + loop = asyncio.new_event_loop() + + for chunk in iter_over_async( + request(appid, api_key, api_secret, Spark_url, domain, question, params.get("temperature")), + loop=loop, + ): + if chunk: + print(chunk) + text += chunk + yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0" + + def get_embeddings(self, params): + # TODO: 支持embeddings + print("embedding") + print(params) + + +if __name__ == "__main__": + import uvicorn + from server.utils import MakeFastAPIOffline + from fastchat.serve.model_worker import app + + worker = XingHuoWorker( + controller_addr="http://127.0.0.1:20001", + worker_addr="http://127.0.0.1:20005", + ) + sys.modules["fastchat.serve.model_worker"].worker = worker + MakeFastAPIOffline(app) + uvicorn.run(app, port=20005)