diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py
index a3a162f..611d94e 100644
--- a/server/model_workers/__init__.py
+++ b/server/model_workers/__init__.py
@@ -2,3 +2,4 @@ from .zhipu import ChatGLMWorker
 from .minimax import MiniMaxWorker
 from .xinghuo import XingHuoWorker
 from .qianfan import QianFanWorker
+from .fangzhou import FangZhouWorker
diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py
index 08820de..5207fdb 100644
--- a/server/model_workers/fangzhou.py
+++ b/server/model_workers/fangzhou.py
@@ -3,9 +3,52 @@ from configs.model_config import TEMPERATURE
 from fastchat import conversation as conv
 import sys
 import json
+from pprint import pprint
+from server.utils import get_model_worker_config
 from typing import List, Literal, Dict
 
 
+def request_volc_api(
+    messages: List[Dict],
+    model_name: str = "fangzhou-api",
+    version: str = None,
+    temperature: float = TEMPERATURE,
+    api_key: str = None,
+    secret_key: str = None,
+):
+    from volcengine.maas import MaasService, MaasException, ChatRole
+
+    maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
+    config = get_model_worker_config(model_name)
+    version = version or config.get("version")
+    version_url = config.get("version_url")
+    api_key = api_key or config.get("api_key")
+    secret_key = secret_key or config.get("secret_key")
+
+    maas.set_ak(api_key)
+    maas.set_sk(secret_key)
+
+    # document: "https://www.volcengine.com/docs/82379/1099475"
+    req = {
+        "model": {
+            "name": version,
+        },
+        "parameters": {
+            # the parameters here are examples only; see each model's API docs for the parameters it actually supports
+            "max_new_tokens": 1000,
+            "temperature": temperature,
+        },
+        "messages": messages,
+    }
+
+    try:
+        resps = maas.stream_chat(req)
+        for resp in resps:
+            yield resp
+    except MaasException as e:
+        print(e)
+
+
 class FangZhouWorker(ApiModelWorker):
     """
     FangZhou (Volcano Ark)
@@ -21,7 +64,7 @@ class FangZhouWorker(ApiModelWorker):
         **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 16384)
+        kwargs.setdefault("context_len", 16384)  # TODO: different models have different context sizes
         super().__init__(**kwargs)
 
         config = self.get_config()
@@ -41,40 +84,29 @@ class FangZhouWorker(ApiModelWorker):
 
     def generate_stream_gate(self, params):
         super().generate_stream_gate(params)
-        from volcengine.maas import MaasService, MaasException, ChatRole
 
-        maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
-        maas.set_ak(self.api_key)
-        maas.set_sk(self.secret_key)
-
-        # document: "https://www.volcengine.com/docs/82379/1099475"
-        req = {
-            "model": {
-                "name": self.version,
-            },
-            "parameters": {
-                # the parameters here are examples only; see each model's API docs for the parameters it actually supports
-                # "max_new_tokens": 1000,
-                "temperature": params.get("temperature", TEMPERATURE),
-            },
-            "messages": [{"role": ChatRole.USER, "content": params["prompt"]}]
-        }
+        messages = self.prompt_to_messages(params["prompt"])
         text = ""
-        try:
-            resps = maas.stream_chat(req)
-            for resp in resps:
-                text += resp.choice.message.content
-                yield json.dumps({"error_code": 0, "text": text},
-                                 ensure_ascii=False).encode() + b"\0"
-        except MaasException as e:
-            print(e)
+        for resp in request_volc_api(messages=messages,
+                                     model_name=self.model_names[0],
+                                     version=self.version,
+                                     temperature=params.get("temperature", TEMPERATURE),
+                                     ):
+            error = resp.error
+            if error.code_n > 0:
+                data = {"error_code": error.code_n, "text": error.message}
+            elif chunk := resp.choice.message.content:
+                text += chunk
+                data = {"error_code": 0, "text": text}
+            else:
+                continue  # empty chunk with no error: nothing new to yield
+            yield json.dumps(data, ensure_ascii=False).encode() + b"\0"
 
-
-def get_embeddings(self, params):
-    # TODO: support embeddings
-    print("embedding")
-    print(params)
+    def get_embeddings(self, params):
+        # TODO: support embeddings
+        print("embedding")
+        print(params)
 
 
 if __name__ == "__main__":
@@ -84,8 +116,8 @@ if __name__ == "__main__":
 
     worker = FangZhouWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20006",
+        worker_addr="http://127.0.0.1:21005",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20006)
+    uvicorn.run(app, port=21005)
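For reference, request_volc_api can also be exercised directly, outside the worker. A minimal sketch, assuming a "fangzhou-api" entry with valid api_key/secret_key in the model worker config; "chatglm-6b-model" is only an example endpoint version:

    from server.model_workers.fangzhou import request_volc_api

    messages = [{"role": "user", "content": "hello"}]
    text = ""
    for resp in request_volc_api(messages, version="chatglm-6b-model"):
        if resp.error.code_n > 0:            # non-zero code_n marks an API-side error
            print("error:", resp.error.message)
            break
        text += resp.choice.message.content  # each response carries an incremental chunk
    print(text)

Note that generate_stream_gate accumulates these chunks and yields the full text so far in each b"\0"-terminated JSON frame, not a delta.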
diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py
index c772c0d..39ff293 100644
--- a/server/model_workers/minimax.py
+++ b/server/model_workers/minimax.py
@@ -93,8 +93,8 @@ if __name__ == "__main__":
 
     worker = MiniMaxWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20004",
+        worker_addr="http://127.0.0.1:21002",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20003)
+    uvicorn.run(app, port=21002)
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index f7a5161..387d4b7 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -168,8 +168,8 @@ if __name__ == "__main__":
 
     worker = QianFanWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20006",
+        worker_addr="http://127.0.0.1:21004",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20006)
\ No newline at end of file
+    uvicorn.run(app, port=21004)
\ No newline at end of file
diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py
index 499e8bc..bc98a9c 100644
--- a/server/model_workers/xinghuo.py
+++ b/server/model_workers/xinghuo.py
@@ -94,8 +94,8 @@ if __name__ == "__main__":
 
     worker = XingHuoWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20005",
+        worker_addr="http://127.0.0.1:21003",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20005)
+    uvicorn.run(app, port=21003)
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index f835ac0..18cec5b 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -67,8 +67,8 @@ if __name__ == "__main__":
 
     worker = ChatGLMWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20003",
+        worker_addr="http://127.0.0.1:21001",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20003)
+    uvicorn.run(app, port=21001)
diff --git a/tests/online_api/test_fangzhou.py b/tests/online_api/test_fangzhou.py
new file mode 100644
index 0000000..1157537
--- /dev/null
+++ b/tests/online_api/test_fangzhou.py
@@ -0,0 +1,20 @@
+import sys
+from pathlib import Path
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from server.model_workers.fangzhou import request_volc_api
+from pprint import pprint
+import pytest
+
+
+@pytest.mark.parametrize("version", ["chatglm-6b-model"])
+def test_fangzhou(version):
+    messages = [{"role": "user", "content": "hello"}]
+    print("\n" + version + "\n")
+    for x in request_volc_api(messages, version=version):
+        print(type(x))
+        pprint(x)
+        if chunk := x.choice.message.content:
+            print(chunk)
+        assert x.choice.message
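Like the other online-API tests, test_fangzhou talks to the live Volcano Engine endpoint and fails without credentials. A possible guard, hypothetical and not part of this patch, that skips the test when no api_key/secret_key is configured (the config keys mirror the ones request_volc_api reads):

    import pytest
    from server.utils import get_model_worker_config

    _config = get_model_worker_config("fangzhou-api")

    # Skip the live-API test when credentials are missing.
    requires_fangzhou = pytest.mark.skipif(
        not (_config.get("api_key") and _config.get("secret_key")),
        reason="fangzhou api_key/secret_key not configured",
    )

It would be applied as @requires_fangzhou on test_fangzhou.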
diff --git a/tests/online_api/test_qianfan.py b/tests/online_api/test_qianfan.py
index 0e8a948..b4b9b15 100644
--- a/tests/online_api/test_qianfan.py
+++ b/tests/online_api/test_qianfan.py
@@ -8,7 +8,7 @@ from pprint import pprint
 import pytest
 
 
-@pytest.mark.parametrize("version", MODEL_VERSIONS.keys())
+@pytest.mark.parametrize("version", list(MODEL_VERSIONS.keys())[:2])
 def test_qianfan(version):
     messages = [{"role": "user", "content": "你好"}]
     print("\n" + version + "\n")
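Taken together, the __main__ blocks now give every online-API worker a distinct standalone port in the 21001-21005 range, fixing the earlier collisions: qianfan and fangzhou both claimed 20006, and minimax advertised 20004 while uvicorn actually ran on 20003. A summary as a hypothetical Python mapping (the dict does not exist in the repo, and worker names other than "fangzhou-api" are assumed from the module names):

    # Per-worker standalone ports after this change, from the __main__ blocks above.
    ONLINE_WORKER_PORTS = {
        "zhipu-api":    21001,  # ChatGLMWorker  (was 20003)
        "minimax-api":  21002,  # MiniMaxWorker  (advertised 20004, ran on 20003)
        "xinghuo-api":  21003,  # XingHuoWorker  (was 20005)
        "qianfan-api":  21004,  # QianFanWorker  (was 20006)
        "fangzhou-api": 21005,  # FangZhouWorker (was 20006, clashing with qianfan)
    }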