Merge branch 'pr1501' into dev

liunux4odoo 2023-09-17 00:30:24 +08:00
commit 7577bd52ce
9 changed files with 161 additions and 9 deletions


@@ -115,6 +115,14 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "QianFanWorker",
     },
+    # VolcEngine Ark (火山方舟) API
+    "fangzhou-api": {
+        "version": "chatglm-6b-model",  # currently supports "chatglm-6b-model"; see the Ark section of the supported-model list in the docs for more
+        "version_url": "",  # instead of setting version, you can directly fill in the API URL of the model endpoint you published on Ark
+        "api_key": "",
+        "secret_key": "",
+        "provider": "FangZhouWorker",
+    },
 }
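
For illustration, a filled-in entry might look like the sketch below; the api_key and secret_key values are placeholders, not real VolcEngine credentials:

    "fangzhou-api": {
        "version": "chatglm-6b-model",  # model name published on Ark
        "version_url": "",              # or paste the Ark endpoint URL here and leave version empty
        "api_key": "AKLT...placeholder...",
        "secret_key": "placeholder-secret",
        "provider": "FangZhouWorker",
    },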


@@ -2,3 +2,4 @@ from .zhipu import ChatGLMWorker
 from .minimax import MiniMaxWorker
 from .xinghuo import XingHuoWorker
 from .qianfan import QianFanWorker
+from .fangzhou import FangZhouWorker
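
With this export in place, the new worker is importable from the package root alongside the existing ones:

    from server.model_workers import FangZhouWorker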


@@ -0,0 +1,121 @@
from server.model_workers.base import ApiModelWorker
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv
import sys
import json
from pprint import pprint
from server.utils import get_model_worker_config
from typing import List, Literal, Dict


def request_volc_api(
    messages: List[Dict],
    model_name: str = "fangzhou-api",
    version: str = "chatglm-6b-model",
    temperature: float = TEMPERATURE,
    api_key: str = None,
    secret_key: str = None,
):
    from volcengine.maas import MaasService, MaasException, ChatRole

    maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
    config = get_model_worker_config(model_name)
    version = version or config.get("version")
    version_url = config.get("version_url")
    api_key = api_key or config.get("api_key")
    secret_key = secret_key or config.get("secret_key")
    maas.set_ak(api_key)
    maas.set_sk(secret_key)

    # document: "https://www.volcengine.com/docs/82379/1099475"
    req = {
        "model": {
            "name": version,
        },
        "parameters": {
            # these parameters are only examples; see the API docs of the
            # specific model for the parameters it actually supports
            "max_new_tokens": 1000,
            "temperature": temperature,
        },
        "messages": messages,
    }

    try:
        resps = maas.stream_chat(req)
        for resp in resps:
            yield resp
    except MaasException as e:
        print(e)


class FangZhouWorker(ApiModelWorker):
    """
    VolcEngine Ark (火山方舟) API worker
    """
    SUPPORT_MODELS = ["chatglm-6b-model"]

    def __init__(
        self,
        *,
        version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
        model_names: List[str] = ["fangzhou-api"],
        controller_addr: str,
        worker_addr: str,
        **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 16384)  # TODO: different models have different context sizes
        super().__init__(**kwargs)

        config = self.get_config()
        self.version = version
        self.api_key = config.get("api_key")
        self.secret_key = config.get("secret_key")

        from volcengine.maas import ChatRole
        # Chinese system prompt: "You are a smart AI that is helpful to humans;
        # give useful, detailed, polite answers to the questions people ask."
        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
            messages=[],
            roles=[ChatRole.USER, ChatRole.ASSISTANT, ChatRole.SYSTEM],
            sep="\n### ",
            stop_str="###",
        )

    def generate_stream_gate(self, params):
        super().generate_stream_gate(params)
        messages = self.prompt_to_messages(params["prompt"])
        text = ""

        for resp in request_volc_api(messages=messages,
                                     model_name=self.model_names[0],
                                     version=self.version,
                                     temperature=params.get("temperature", TEMPERATURE),
                                     ):
            error = resp.error
            if error.code_n > 0:
                data = {"error_code": error.code_n, "text": error.message}
            elif chunk := resp.choice.message.content:
                text += chunk
                data = {"error_code": 0, "text": text}
            else:
                continue  # skip empty chunks so `data` is never unbound
            yield json.dumps(data, ensure_ascii=False).encode() + b"\0"

    def get_embeddings(self, params):
        # TODO: support embeddings
        print("embedding")
        print(params)


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = FangZhouWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21005",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21005)
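
For a quick manual check outside the worker process, request_volc_api can also be driven directly, in the same way the new test file below does; this sketch assumes valid api_key and secret_key values in the fangzhou-api config entry:

    from server.model_workers.fangzhou import request_volc_api

    messages = [{"role": "user", "content": "hello"}]
    for resp in request_volc_api(messages, version="chatglm-6b-model"):
        if resp.error.code_n > 0:            # non-zero code_n signals an API error
            print(resp.error.message)
        elif resp.choice.message.content:    # each resp carries one streamed chunk
            print(resp.choice.message.content, end="")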


@@ -93,8 +93,8 @@ if __name__ == "__main__":
     worker = MiniMaxWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20004",
+        worker_addr="http://127.0.0.1:21002",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20003)
+    uvicorn.run(app, port=21002)


@@ -168,8 +168,8 @@ if __name__ == "__main__":
     worker = QianFanWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20006",
+        worker_addr="http://127.0.0.1:21004",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20006)
+    uvicorn.run(app, port=21004)


@@ -94,8 +94,8 @@ if __name__ == "__main__":
     worker = XingHuoWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20005",
+        worker_addr="http://127.0.0.1:21003",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20005)
+    uvicorn.run(app, port=21003)


@@ -67,8 +67,8 @@ if __name__ == "__main__":
     worker = ChatGLMWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20003",
+        worker_addr="http://127.0.0.1:21001",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20003)
+    uvicorn.run(app, port=21001)
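
Taken together with the new FangZhou worker at 21005, these four hunks move the online-API workers onto a contiguous port block (zhipu 21001, minimax 21002, xinghuo 21003, qianfan 21004, fangzhou 21005). They also fix the earlier mismatch where a worker's advertised worker_addr port differed from the port uvicorn actually served on (minimax advertised 20004 but listened on 20003).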


@@ -0,0 +1,22 @@
import sys
from pathlib import Path

root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))

from server.model_workers.fangzhou import request_volc_api
from pprint import pprint
import pytest


@pytest.mark.parametrize("version", ["chatglm-6b-model"])
def test_fangzhou(version):
    messages = [{"role": "user", "content": "hello"}]
    print("\n" + version + "\n")

    for x in request_volc_api(messages, version=version):
        print(type(x))
        pprint(x)
        if chunk := x.choice.message.content:
            print(chunk)
        assert x.choice.message
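
Because the test reports progress through print and pprint rather than assertions alone, run it with output capture disabled (pytest -s) to see the streamed responses.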


@@ -8,7 +8,7 @@ from pprint import pprint
 import pytest

-@pytest.mark.parametrize("version", MODEL_VERSIONS.keys())
+@pytest.mark.parametrize("version", list(MODEL_VERSIONS.keys())[:2])
 def test_qianfan(version):
     messages = [{"role": "user", "content": "你好"}]
     print("\n" + version + "\n")