diff --git a/README.md b/README.md
index 869fd72..e1fd464 100644
--- a/README.md
+++ b/README.md
@@ -133,6 +133,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
 - [MiniMax](https://api.minimax.chat)
 - [讯飞星火](https://xinghuo.xfyun.cn)
 - [百度千帆](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
+- [阿里云通义千问](https://dashscope.aliyun.com/)
 
 The project uses `THUDM/chatglm2-6b` as the default LLM; to use another LLM, edit `llm_model_dict` and `LLM_MODEL` in [configs/model_config.py].
 
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index c6f1dc4..5fb0691 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -115,7 +115,7 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "QianFanWorker",
     },
-    # 火山方舟 API
+    # 火山方舟 API, docs: https://www.volcengine.com/docs/82379
     "fangzhou-api": {
         "version": "chatglm-6b-model",  # currently supports "chatglm-6b-model"; see the 方舟 section of the supported-model list in the docs for more
         "version_url": "",  # version may be omitted if you fill in the API URL of the model deployment published in 方舟
@@ -123,6 +123,12 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "FangZhouWorker",
     },
+    # 阿里云通义千问 API, docs: https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+    "qwen-api": {
+        "version": "qwen-turbo",  # options include "qwen-turbo" and "qwen-plus"
+        "api_key": "",  # create one on the API-KEY management page of the DashScope (模型服务灵积) console on Alibaba Cloud
+        "provider": "QwenWorker",
+    },
 }
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index 7a4e318..11fc9fd 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -79,6 +79,9 @@ FSCHAT_MODEL_WORKERS = {
     "fangzhou-api": {
         "port": 21005,
     },
+    "qwen-api": {
+        "port": 21006,
+    },
 }
 
 # fastchat multi model worker server
diff --git a/requirements.txt b/requirements.txt
index 5699d46..fe09f5b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,6 +24,12 @@ pytest
 scikit-learn
 numexpr
 
+# online api libs
+# zhipuai
+# dashscope>=1.10.0 # qwen
+# qianfan
+# volcengine>=1.0.106 # fangzhou
+
 # uncomment libs if you want to use corresponding vector store
 # pymilvus==2.1.3 # requires milvus==2.1.3
 # psycopg2
diff --git a/requirements_api.txt b/requirements_api.txt
index c56c07b..5ba0510 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -24,6 +24,12 @@ pytest
 scikit-learn
 numexpr
 
+# online api libs
+# zhipuai
+# dashscope>=1.10.0 # qwen
+# qianfan
+# volcengine>=1.0.106 # fangzhou
+
 # uncomment libs if you want to use corresponding vector store
 # pymilvus==2.1.3 # requires milvus==2.1.3
 # psycopg2
diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py
index 611d94e..c1a824b 100644
--- a/server/model_workers/__init__.py
+++ b/server/model_workers/__init__.py
@@ -3,3 +3,4 @@ from .minimax import MiniMaxWorker
 from .xinghuo import XingHuoWorker
 from .qianfan import QianFanWorker
 from .fangzhou import FangZhouWorker
+from .qwen import QwenWorker
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index 653e29b..515c5db 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -92,5 +92,5 @@ class ApiModelWorker(BaseModelWorker):
             if content := msg[len(ai_start):].strip():
                 result.append({"role": ai_role, "content": content})
         else:
-            raise RuntimeError(f"unknow role in msg: {msg}")
+            raise RuntimeError(f"unknown role in msg: {msg}")
         return result
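For reference, a minimal standalone sketch of the DashScope streaming call that the new `QwenWorker` (next file) wraps; it is handy for verifying an API key before wiring up the worker. It assumes `dashscope>=1.10.0` (the pin commented into requirements.txt above), and the `api_key` value is a hypothetical placeholder:

```python
# Standalone check of the DashScope call that request_qwen_api (below) makes.
# Assumes dashscope>=1.10.0; replace the placeholder key with a real one.
from http import HTTPStatus

import dashscope

responses = dashscope.Generation.call(
    model="qwen-turbo",                # or "qwen-plus", matching the config above
    api_key="YOUR-DASHSCOPE-API-KEY",  # normally read from ONLINE_LLM_MODEL["qwen-api"]
    messages=[{"role": "user", "content": "hello"}],
    result_format="message",           # return OpenAI-style choices/messages
    stream=True,                       # iterate over streaming responses
)
for resp in responses:
    if resp.status_code == HTTPStatus.OK:
        print(resp["output"]["choices"][0]["message"]["content"])
    else:
        print(resp["status_code"], resp["message"])
```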
diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py
new file mode 100644
index 0000000..32d8757
--- /dev/null
+++ b/server/model_workers/qwen.py
@@ -0,0 +1,123 @@
+import json
+import sys
+from configs import TEMPERATURE
+from http import HTTPStatus
+from typing import List, Literal, Dict
+
+from fastchat import conversation as conv
+
+from server.model_workers.base import ApiModelWorker
+from server.utils import get_model_worker_config
+
+
+def request_qwen_api(
+        messages: List[Dict[str, str]],
+        api_key: str = None,
+        version: str = "qwen-turbo",
+        temperature: float = TEMPERATURE,
+        model_name: str = "qwen-api",
+):
+    import dashscope
+
+    config = get_model_worker_config(model_name)
+    api_key = api_key or config.get("api_key")
+    version = version or config.get("version")
+
+    gen = dashscope.Generation()
+    responses = gen.call(
+        model=version,
+        temperature=temperature,
+        api_key=api_key,
+        messages=messages,
+        result_format='message',  # request message-format results
+        stream=True,
+    )
+
+    for resp in responses:
+        if resp["status_code"] == HTTPStatus.OK:
+            if choices := resp["output"]["choices"]:
+                yield {
+                    "code": 200,
+                    "text": choices[0]["message"]["content"],
+                }
+        else:  # yield the error exactly once, with the API's own message
+            yield {
+                "code": resp["status_code"],
+                "text": resp["message"],
+            }
+
+
+class QwenWorker(ApiModelWorker):
+    def __init__(
+            self,
+            *,
+            version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
+            model_names: List[str] = ["qwen-api"],
+            controller_addr: str,
+            worker_addr: str,
+            **kwargs,
+    ):
+        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+        kwargs.setdefault("context_len", 16384)
+        super().__init__(**kwargs)
+
+        # TODO: confirm whether this conversation template needs changes
+        self.conv = conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=["user", "assistant", "system"],
+            sep="\n### ",
+            stop_str="###",
+        )
+        config = self.get_config()
+        self.api_key = config.get("api_key")
+        self.version = version
+
+    def generate_stream_gate(self, params):
+        messages = self.prompt_to_messages(params["prompt"])
+
+        for resp in request_qwen_api(messages=messages,
+                                     api_key=self.api_key,
+                                     version=self.version,
+                                     temperature=params.get("temperature", TEMPERATURE)):
+            if resp["code"] == 200:
+                yield json.dumps(
+                    {"error_code": 0, "text": resp["text"]},
+                    ensure_ascii=False,
+                ).encode() + b"\0"
+            else:
+                yield json.dumps(
+                    {"error_code": resp["code"], "text": resp["text"]},
+                    ensure_ascii=False,
+                ).encode() + b"\0"
+
+    def get_embeddings(self, params):
+        # TODO: support embeddings
+        print("embedding")
+        print(params)
+
+
+if __name__ == "__main__":
+    import uvicorn
+    from server.utils import MakeFastAPIOffline
+    from fastchat.serve.model_worker import app
+
+    worker = QwenWorker(
+        controller_addr="http://127.0.0.1:20001",
+        worker_addr="http://127.0.0.1:20007",
+    )
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    MakeFastAPIOffline(app)
+    uvicorn.run(app, port=20007)
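Once the `__main__` block above is running (it serves a FastChat worker app on port 20007), the stream can be exercised over HTTP. A sketch, assuming FastChat's standard `/worker_generate_stream` route: `generate_stream_gate` yields UTF-8 JSON chunks, each terminated by a NUL byte, so the client splits on `b"\0"`.

```python
# Consume the worker's NUL-delimited JSON stream (FastChat convention).
import json

import requests

with requests.post(
    "http://127.0.0.1:20007/worker_generate_stream",
    json={"prompt": "hello", "temperature": 0.7},
    stream=True,
) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:  # skip empty fragments between NUL terminators
            data = json.loads(chunk.decode("utf-8"))
            print(data["error_code"], data["text"])
```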
diff --git a/tests/online_api/test_qwen.py b/tests/online_api/test_qwen.py
new file mode 100644
index 0000000..001cf60
--- /dev/null
+++ b/tests/online_api/test_qwen.py
@@ -0,0 +1,19 @@
+import sys
+from pathlib import Path
+
+# make the project root importable when running the test file directly
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from server.model_workers.qwen import request_qwen_api
+from pprint import pprint
+import pytest
+
+
+@pytest.mark.parametrize("version", ["qwen-turbo"])
+def test_qwen(version):
+    messages = [{"role": "user", "content": "hello"}]
+    print("\n" + version + "\n")
+
+    for x in request_qwen_api(messages, version=version):
+        print(type(x))
+        pprint(x)
+        assert x["code"] == 200
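Note that this test exercises the live DashScope endpoint: when no `api_key` argument is given, `request_qwen_api` falls back to the key configured under `ONLINE_LLM_MODEL["qwen-api"]`, so running it requires `dashscope` installed and that key filled in. Invoke it with `pytest tests/online_api/test_qwen.py -s` to see each streamed chunk printed by `pprint`.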