[Feature] Online LLM models: add support for Alibaba Cloud Tongyi Qianwen (Qwen) (#1534)
* feat: add qwen-api
* Support the temperature parameter in the Qwen API; add test cases
* Make the online-api SDKs optional dependencies
---------
Co-authored-by: liunux4odoo <liunux@qq.com>
parent b161985d79
commit 9bcce0a572
@@ -133,6 +133,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
 - [MiniMax](https://api.minimax.chat)
 - [讯飞星火](https://xinghuo.xfyun.cn)
 - [百度千帆](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
+- [阿里云通义千问](https://dashscope.aliyun.com/)

 The default LLM used by the project is `THUDM/chatglm2-6b`. To use another LLM, edit `llm_model_dict` and `LLM_MODEL` in [configs/model_config.py].
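Pointing a deployment at the new online model is then a one-line edit; a minimal sketch (the `"qwen-api"` key follows the `ONLINE_LLM_MODEL` entry added below):

```python
# configs/model_config.py (sketch) — select the new Qwen online worker
# instead of the default local ChatGLM2 model.
LLM_MODEL = "qwen-api"
```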
@@ -115,7 +115,7 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "QianFanWorker",
     },
-    # Volcano Ark (火山方舟) API
+    # Volcano Ark (火山方舟) API; see the docs at https://www.volcengine.com/docs/82379
     "fangzhou-api": {
         "version": "chatglm-6b-model",  # currently supports "chatglm-6b-model"; see the Ark section of the docs' supported-model list for more
         "version_url": "",  # version may be left empty; fill in the API URL of the model endpoint published in Ark instead
@@ -123,6 +123,12 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "FangZhouWorker",
     },
+    # Alibaba Cloud Qwen (通义千问) API; see the docs at https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+    "qwen-api": {
+        "version": "qwen-turbo",  # options include "qwen-turbo" and "qwen-plus"
+        "api_key": "",  # create one on the DashScope (灵积) API-KEY management page in the Alibaba Cloud console
+        "provider": "QwenWorker",
+    },
 }
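At runtime this entry is fetched through `server.utils.get_model_worker_config` (the helper the new worker imports below). A rough usage sketch, assuming the helper simply merges and returns the `ONLINE_LLM_MODEL` entry for the given name:

```python
from server.utils import get_model_worker_config

# Assumed behavior: returns the merged worker config for "qwen-api",
# including the fields defined in ONLINE_LLM_MODEL above.
config = get_model_worker_config("qwen-api")
api_key = config.get("api_key")    # DashScope API key
version = config.get("version")    # "qwen-turbo" or "qwen-plus"
```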
@@ -79,6 +79,9 @@ FSCHAT_MODEL_WORKERS = {
     "fangzhou-api": {
         "port": 21005,
     },
+    "qwen-api": {
+        "port": 21006,
+    },
 }

 # fastchat multi model worker server
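Each online-API worker gets a dedicated FastChat worker port. A sketch of how the new entry resolves to a worker address (the module path and the localhost host are assumptions):

```python
from configs.server_config import FSCHAT_MODEL_WORKERS  # assumed module path

port = FSCHAT_MODEL_WORKERS["qwen-api"]["port"]  # 21006
worker_addr = f"http://127.0.0.1:{port}"         # where the QwenWorker is expected to serve
```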
@@ -24,6 +24,12 @@ pytest
 scikit-learn
 numexpr

+# online api libs
+# zhipuai
+# dashscope>=1.10.0 # qwen
+# qianfan
+# volcengine>=1.0.106 # fangzhou
+
 # uncomment libs if you want to use corresponding vector store
 # pymilvus==2.1.3 # requires milvus==2.1.3
 # psycopg2
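Because the SDKs are now commented out by default, nothing should import them at module load time; the new `request_qwen_api` below defers `import dashscope` into the function body for exactly this reason. A guarded-import sketch of the same idea:

```python
# Sketch: treat dashscope as optional; enable it with: pip install "dashscope>=1.10.0"
try:
    import dashscope
except ImportError:
    dashscope = None  # qwen-api should then fail with a clear error, not at startup
```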
@@ -24,6 +24,12 @@ pytest
 scikit-learn
 numexpr

+# online api libs
+# zhipuai
+# dashscope>=1.10.0 # qwen
+# qianfan
+# volcengine>=1.0.106 # fangzhou
+
 # uncomment libs if you want to use corresponding vector store
 # pymilvus==2.1.3 # requires milvus==2.1.3
 # psycopg2
@@ -3,3 +3,4 @@ from .minimax import MiniMaxWorker
 from .xinghuo import XingHuoWorker
 from .qianfan import QianFanWorker
 from .fangzhou import FangZhouWorker
+from .qwen import QwenWorker
@@ -92,5 +92,5 @@ class ApiModelWorker(BaseModelWorker):
             if content := msg[len(ai_start):].strip():
                 result.append({"role": ai_role, "content": content})
             else:
-                raise RuntimeError(f"unknow role in msg: {msg}")
+                raise RuntimeError(f"unknown role in msg: {msg}")
         return result
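For context, this hunk sits in `ApiModelWorker.prompt_to_messages`, which splits a FastChat-style prompt back into role/content messages. A hypothetical round-trip, assuming the `"\n### "`-separated framing used by the QwenWorker conversation template below:

```python
# Hypothetical illustration only — the exact prompt framing is an assumption.
prompt = "### user: 你好\n### assistant: 你好,有什么可以帮你?\n### user: 讲个笑话"
# prompt_to_messages(prompt) would then return something like:
# [{"role": "user", "content": "你好"},
#  {"role": "assistant", "content": "你好,有什么可以帮你?"},
#  {"role": "user", "content": "讲个笑话"}]
```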
@@ -0,0 +1,123 @@
+import json
+import sys
+from configs import TEMPERATURE
+from http import HTTPStatus
+from typing import List, Literal, Dict
+
+from fastchat import conversation as conv
+
+from server.model_workers.base import ApiModelWorker
+from server.utils import get_model_worker_config
+
+
+def request_qwen_api(
+    messages: List[Dict[str, str]],
+    api_key: str = None,
+    version: str = "qwen-turbo",
+    temperature: float = TEMPERATURE,
+    model_name: str = "qwen-api",
+):
+    import dashscope  # imported lazily: the SDK is an optional dependency
+
+    # explicit arguments take precedence over the ONLINE_LLM_MODEL config
+    config = get_model_worker_config(model_name)
+    api_key = api_key or config.get("api_key")
+    version = version or config.get("version")
+
+    gen = dashscope.Generation()
+    responses = gen.call(
+        model=version,
+        temperature=temperature,
+        api_key=api_key,
+        messages=messages,
+        result_format='message',  # return results in message format
+        stream=True,
+    )
+
+    for resp in responses:
+        if resp["status_code"] == HTTPStatus.OK:
+            if choices := resp["output"]["choices"]:
+                yield {
+                    "code": 200,
+                    "text": choices[0]["message"]["content"],
+                }
+        else:
+            yield {
+                "code": resp["status_code"],
+                "text": resp["message"],
+            }
+
+
+class QwenWorker(ApiModelWorker):
+    def __init__(
+        self,
+        *,
+        version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
+        model_names: List[str] = ["qwen-api"],
+        controller_addr: str,
+        worker_addr: str,
+        **kwargs,
+    ):
+        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+        kwargs.setdefault("context_len", 16384)
+        super().__init__(**kwargs)
+
+        # TODO: confirm whether this template needs further changes
+        self.conv = conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=["user", "assistant", "system"],
+            sep="\n### ",
+            stop_str="###",
+        )
+        config = self.get_config()
+        self.api_key = config.get("api_key")
+        self.version = version
+
+    def generate_stream_gate(self, params):
+        messages = self.prompt_to_messages(params["prompt"])
+
+        for resp in request_qwen_api(messages=messages,
+                                     api_key=self.api_key,
+                                     version=self.version,
+                                     temperature=params.get("temperature")):
+            if resp["code"] == 200:
+                yield json.dumps(
+                    {"error_code": 0, "text": resp["text"]},
+                    ensure_ascii=False,
+                ).encode() + b"\0"
+            else:
+                yield json.dumps(
+                    {"error_code": resp["code"], "text": resp["text"]},
+                    ensure_ascii=False,
+                ).encode() + b"\0"
+
+    def get_embeddings(self, params):
+        # TODO: support embeddings
+        print("embedding")
+        print(params)
+
+
+if __name__ == "__main__":
+    import uvicorn
+    from server.utils import MakeFastAPIOffline
+    from fastchat.serve.model_worker import app
+
+    worker = QwenWorker(
+        controller_addr="http://127.0.0.1:20001",
+        worker_addr="http://127.0.0.1:20007",
+    )
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    MakeFastAPIOffline(app)
+    uvicorn.run(app, port=20007)
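A standalone call sketch for the new API (the key shown is a placeholder; normally it is read from `ONLINE_LLM_MODEL["qwen-api"]`):

```python
from server.model_workers.qwen import request_qwen_api

# Chunks stream in as {"code": ..., "text": ...} dicts.
for chunk in request_qwen_api(
    [{"role": "user", "content": "你好"}],
    api_key="sk-...",       # placeholder
    version="qwen-turbo",
    temperature=0.7,
):
    print(chunk["code"], chunk["text"])
```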
@@ -0,0 +1,19 @@
+import sys
+from pathlib import Path
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from server.model_workers.qwen import request_qwen_api
+from pprint import pprint
+import pytest
+
+
+@pytest.mark.parametrize("version", ["qwen-turbo"])
+def test_qwen(version):
+    messages = [{"role": "user", "content": "hello"}]
+    print("\n" + version + "\n")
+
+    for x in request_qwen_api(messages, version=version):
+        print(type(x))
+        pprint(x)
+        assert x["code"] == 200
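Note that this test hits the live DashScope endpoint: it only passes when a valid `api_key` is configured for `qwen-api` (or passed explicitly); otherwise `request_qwen_api` yields a non-200 code and the assertion fails.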