diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index dff60f7..c6f1dc4 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -115,6 +115,14 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "QianFanWorker",
     },
+    # 火山方舟 (Volcengine Ark) API
+    "fangzhou-api": {
+        "version": "chatglm-6b-model",  # currently only "chatglm-6b-model" is supported; see the Ark section of the supported-model list in the docs for more
+        "version_url": "",  # instead of setting version, you can fill in the API URL of the model endpoint you deployed in Ark
+        "api_key": "",
+        "secret_key": "",
+        "provider": "FangZhouWorker",
+    },
 }
diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py
index a3a162f..611d94e 100644
--- a/server/model_workers/__init__.py
+++ b/server/model_workers/__init__.py
@@ -2,3 +2,4 @@ from .zhipu import ChatGLMWorker
 from .minimax import MiniMaxWorker
 from .xinghuo import XingHuoWorker
 from .qianfan import QianFanWorker
+from .fangzhou import FangZhouWorker
diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py
new file mode 100644
index 0000000..5207fdb
--- /dev/null
+++ b/server/model_workers/fangzhou.py
@@ -0,0 +1,124 @@
+from server.model_workers.base import ApiModelWorker
+from configs.model_config import TEMPERATURE
+from fastchat import conversation as conv
+import sys
+import json
+from pprint import pprint
+from server.utils import get_model_worker_config
+from typing import List, Literal, Dict
+
+
+def request_volc_api(
+        messages: List[Dict],
+        model_name: str = "fangzhou-api",
+        version: str = "chatglm-6b-model",
+        temperature: float = TEMPERATURE,
+        api_key: str = None,
+        secret_key: str = None,
+):
+    from volcengine.maas import MaasService, MaasException
+
+    maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
+    config = get_model_worker_config(model_name)
+    version = version or config.get("version")
+    version_url = config.get("version_url")  # TODO: read but not used yet; should take precedence over version once supported
+    api_key = api_key or config.get("api_key")
+    secret_key = secret_key or config.get("secret_key")
+
+    maas.set_ak(api_key)
+    maas.set_sk(secret_key)
+
+    # documentation: https://www.volcengine.com/docs/82379/1099475
+    req = {
+        "model": {
+            "name": version,
+        },
+        "parameters": {
+            # these parameters are examples only; see the API docs of the specific model for what it actually supports
+            "max_new_tokens": 1000,
+            "temperature": temperature,
+        },
+        "messages": messages,
+    }
+
+    try:
+        resps = maas.stream_chat(req)
+        for resp in resps:
+            yield resp
+    except MaasException as e:
+        print(e)
+
+
+class FangZhouWorker(ApiModelWorker):
+    """
+    火山方舟 (Volcengine Ark)
+    """
+    SUPPORT_MODELS = ["chatglm-6b-model"]
+
+    def __init__(
+            self,
+            *,
+            version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
+            model_names: List[str] = ["fangzhou-api"],
+            controller_addr: str,
+            worker_addr: str,
+            **kwargs,
+    ):
+        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+        kwargs.setdefault("context_len", 16384)  # TODO: context length differs between models
+        super().__init__(**kwargs)
+
+        config = self.get_config()
+        self.version = version
+        self.api_key = config.get("api_key")
+        self.secret_key = config.get("secret_key")
+
+        from volcengine.maas import ChatRole
+        self.conv = conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明、对人类有帮助的人工智能，你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=[ChatRole.USER, ChatRole.ASSISTANT, ChatRole.SYSTEM],
+            sep="\n### ",
+            stop_str="###",
+        )
+
+    def generate_stream_gate(self, params):
+        super().generate_stream_gate(params)
+
+        messages = self.prompt_to_messages(params["prompt"])
+        text = ""
+
+        for resp in request_volc_api(messages=messages,
+                                     model_name=self.model_names[0],
+                                     version=self.version,
+                                     temperature=params.get("temperature", TEMPERATURE),
+                                     ):
+            error = resp.error
+            if error.code_n > 0:
+                data = {"error_code": error.code_n, "text": error.message}
+            elif chunk := resp.choice.message.content:
+                text += chunk
+                data = {"error_code": 0, "text": text}
+            else:
+                continue  # neither an error nor new content: skip, so `data` is never referenced unbound
+            yield json.dumps(data, ensure_ascii=False).encode() + b"\0"
+
+    def get_embeddings(self, params):
+        # TODO: support embeddings
+        print("embedding")
+        print(params)
+
+
+if __name__ == "__main__":
+    import uvicorn
+    from server.utils import MakeFastAPIOffline
+    from fastchat.serve.model_worker import app
+
+    worker = FangZhouWorker(
+        controller_addr="http://127.0.0.1:20001",
+        worker_addr="http://127.0.0.1:21005",
+    )
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    MakeFastAPIOffline(app)
+    uvicorn.run(app, port=21005)
diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py
index c772c0d..39ff293 100644
--- a/server/model_workers/minimax.py
+++ b/server/model_workers/minimax.py
@@ -93,8 +93,8 @@ if __name__ == "__main__":
 
     worker = MiniMaxWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20004",
+        worker_addr="http://127.0.0.1:21002",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20003)
+    uvicorn.run(app, port=21002)
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index f7a5161..387d4b7 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -168,8 +168,8 @@ if __name__ == "__main__":
 
     worker = QianFanWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20006",
+        worker_addr="http://127.0.0.1:21004",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20006)
\ No newline at end of file
+    uvicorn.run(app, port=21004)
\ No newline at end of file
diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py
index 499e8bc..bc98a9c 100644
--- a/server/model_workers/xinghuo.py
+++ b/server/model_workers/xinghuo.py
@@ -94,8 +94,8 @@ if __name__ == "__main__":
 
     worker = XingHuoWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20005",
+        worker_addr="http://127.0.0.1:21003",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20005)
+    uvicorn.run(app, port=21003)
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index f835ac0..18cec5b 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -67,8 +67,8 @@ if __name__ == "__main__":
 
     worker = ChatGLMWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20003",
+        worker_addr="http://127.0.0.1:21001",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20003)
+    uvicorn.run(app, port=21001)
diff --git a/tests/online_api/test_fangzhou.py b/tests/online_api/test_fangzhou.py
new file mode 100644
index 0000000..1157537
--- /dev/null
+++ b/tests/online_api/test_fangzhou.py
@@ -0,0 +1,22 @@
+import sys
+from pathlib import Path
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from server.model_workers.fangzhou import request_volc_api
+from pprint import pprint
+import pytest
+
+
+@pytest.mark.parametrize("version", ["chatglm-6b-model"])
+def test_fangzhou(version):
+    messages = [{"role": "user", "content": "hello"}]
+    print("\n" + version + "\n")
"\n") + i = 1 + for x in request_volc_api(messages, version=version): + print(type(x)) + pprint(x) + if chunk := x.choice.message.content: + print(chunk) + assert x.choice.message + i += 1 diff --git a/tests/online_api/test_qianfan.py b/tests/online_api/test_qianfan.py index 0e8a948..b4b9b15 100644 --- a/tests/online_api/test_qianfan.py +++ b/tests/online_api/test_qianfan.py @@ -8,7 +8,7 @@ from pprint import pprint import pytest -@pytest.mark.parametrize("version", MODEL_VERSIONS.keys()) +@pytest.mark.parametrize("version", list(MODEL_VERSIONS.keys())[:2]) def test_qianfan(version): messages = [{"role": "user", "content": "你好"}] print("\n" + version + "\n")