【功能新增】在线 LLM 模型支持阿里云通义千问 (#1534)

* feat: add qwen-api

* 使Qwen API支持temperature参数;添加测试用例

* 将online-api的sdk列为可选依赖

---------

Co-authored-by: liunux4odoo <liunux@qq.com>
This commit is contained in:
Leego 2023-09-20 21:34:12 +08:00 committed by GitHub
parent b161985d79
commit 9bcce0a572
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 167 additions and 2 deletions

View File

@ -133,6 +133,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [MiniMax](https://api.minimax.chat)
- [讯飞星火](https://xinghuo.xfyun.cn)
- [百度千帆](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
- [阿里云通义千问](https://dashscope.aliyun.com/)
项目中默认使用的 LLM 类型为 `THUDM/chatglm2-6b`,如需使用其他 LLM 类型,请在 [configs/model_config.py] 中对 `llm_model_dict``LLM_MODEL` 进行修改。

View File

@ -115,7 +115,7 @@ ONLINE_LLM_MODEL = {
"secret_key": "",
"provider": "QianFanWorker",
},
# 火山方舟 API
# 火山方舟 API,文档参考 https://www.volcengine.com/docs/82379
"fangzhou-api": {
"version": "chatglm-6b-model", # 当前支持 "chatglm-6b-model" 更多的见文档模型支持列表中方舟部分。
"version_url": "", # 可以不填写version直接填写在方舟申请模型发布的API地址
@ -123,6 +123,12 @@ ONLINE_LLM_MODEL = {
"secret_key": "",
"provider": "FangZhouWorker",
},
# 阿里云通义千问 API文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
"qwen-api": {
"version": "qwen-turbo", # 可选包括 "qwen-turbo", "qwen-plus"
"api_key": "", # 请在阿里云控制台模型服务灵积API-KEY管理页面创建
"provider": "QwenWorker",
},
}

View File

@ -79,6 +79,9 @@ FSCHAT_MODEL_WORKERS = {
"fangzhou-api": {
"port": 21005,
},
"qwen-api": {
"port": 21006,
},
}
# fastchat multi model worker server

View File

@ -24,6 +24,12 @@ pytest
scikit-learn
numexpr
# online api libs
# zhipuai
# dashscope>=1.10.0 # qwen
# qianfan
# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2

View File

@ -24,6 +24,12 @@ pytest
scikit-learn
numexpr
# online api libs
# zhipuai
# dashscope>=1.10.0 # qwen
# qianfan
# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2

View File

@ -3,3 +3,4 @@ from .minimax import MiniMaxWorker
from .xinghuo import XingHuoWorker
from .qianfan import QianFanWorker
from .fangzhou import FangZhouWorker
from .qwen import QwenWorker

View File

@ -92,5 +92,5 @@ class ApiModelWorker(BaseModelWorker):
if content := msg[len(ai_start):].strip():
result.append({"role": ai_role, "content": content})
else:
raise RuntimeError(f"unknow role in msg: {msg}")
raise RuntimeError(f"unknown role in msg: {msg}")
return result

View File

@ -0,0 +1,123 @@
import json
import sys
from configs import TEMPERATURE
from http import HTTPStatus
from typing import List, Literal, Dict
from fastchat import conversation as conv
from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config
def request_qwen_api(
messages: List[Dict[str, str]],
api_key: str = None,
version: str = "qwen-turbo",
temperature: float = TEMPERATURE,
model_name: str = "qwen-api",
):
import dashscope
config = get_model_worker_config(model_name)
api_key = api_key or config.get("api_key")
version = version or config.get("version")
gen = dashscope.Generation()
responses = gen.call(
model=version,
temperature=temperature,
api_key=api_key,
messages=messages,
result_format='message', # set the result is message format.
stream=True,
)
text = ""
for resp in responses:
if resp.status_code != HTTPStatus.OK:
yield {
"code": resp.status_code,
"text": "api not response correctly",
}
if resp["status_code"] == 200:
if choices := resp["output"]["choices"]:
yield {
"code": 200,
"text": choices[0]["message"]["content"],
}
else:
yield {
"code": resp["status_code"],
"text": resp["message"],
}
class QwenWorker(ApiModelWorker):
def __init__(
self,
*,
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
model_names: List[str] = ["qwen-api"],
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
# TODO: 确认模板是否需要修改
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
config = self.get_config()
self.api_key = config.get("api_key")
self.version = version
def generate_stream_gate(self, params):
messages = self.prompt_to_messages(params["prompt"])
for resp in request_qwen_api(messages=messages,
api_key=self.api_key,
version=self.version,
temperature=params.get("temperature")):
if resp["code"] == 200:
yield json.dumps({
"error_code": 0,
"text": resp["text"]
},
ensure_ascii=False
).encode() + b"\0"
else:
yield json.dumps({
"error_code": resp["code"],
"text": resp["text"]
},
ensure_ascii=False
).encode() + b"\0"
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = QwenWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20007",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20007)

View File

@ -0,0 +1,19 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.qwen import request_qwen_api
from pprint import pprint
import pytest
@pytest.mark.parametrize("version", ["qwen-turbo"])
def test_qwen(version):
messages = [{"role": "user", "content": "hello"}]
print("\n" + version + "\n")
for x in request_qwen_api(messages, version=version):
print(type(x))
pprint(x)
assert x["code"] == 200