Make the 火山方舟 (VolcEngine FangZhou) worker work correctly; add error handling and test cases

liunux4odoo 2023-09-17 00:21:13 +08:00
parent 745a105bae
commit 9a7beef270
8 changed files with 95 additions and 42 deletions

View File

@@ -2,3 +2,4 @@ from .zhipu import ChatGLMWorker
 from .minimax import MiniMaxWorker
 from .xinghuo import XingHuoWorker
 from .qianfan import QianFanWorker
+from .fangzhou import FangZhouWorker
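
With the worker registered in the package's __init__, it can be imported like the existing API workers. A quick hypothetical sanity check, run from the project root:

    from server.model_workers import FangZhouWorker
    print(FangZhouWorker)  # the class resolves, confirming the registration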

View File

@@ -3,9 +3,52 @@ from configs.model_config import TEMPERATURE
 from fastchat import conversation as conv
 import sys
 import json
+from pprint import pprint
+from server.utils import get_model_worker_config
 from typing import List, Literal, Dict


+def request_volc_api(
+    messages: List[Dict],
+    model_name: str = "fangzhou-api",
+    version: str = "chatglm-6b-model",
+    temperature: float = TEMPERATURE,
+    api_key: str = None,
+    secret_key: str = None,
+):
+    from volcengine.maas import MaasService, MaasException, ChatRole
+
+    maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
+    config = get_model_worker_config(model_name)
+    version = version or config.get("version")
+    version_url = config.get("version_url")
+    api_key = api_key or config.get("api_key")
+    secret_key = secret_key or config.get("secret_key")
+    maas.set_ak(api_key)
+    maas.set_sk(secret_key)
+
+    # document: "https://www.volcengine.com/docs/82379/1099475"
+    req = {
+        "model": {
+            "name": version,
+        },
+        "parameters": {
+            # These parameters are examples only; see the API documentation
+            # of the specific model for the parameters it actually supports.
+            "max_new_tokens": 1000,
+            "temperature": temperature,
+        },
+        "messages": messages,
+    }
+
+    try:
+        resps = maas.stream_chat(req)
+        for resp in resps:
+            yield resp
+    except MaasException as e:
+        print(e)
+
+
 class FangZhouWorker(ApiModelWorker):
     """
     火山方舟
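
For reference, a minimal sketch of driving the new helper directly. It assumes valid api_key/secret_key values are configured for the fangzhou-api model entry; the error fields mirror how generate_stream_gate consumes the responses below:

    from server.model_workers.fangzhou import request_volc_api

    messages = [{"role": "user", "content": "hello"}]
    for resp in request_volc_api(messages, version="chatglm-6b-model"):
        if resp.error.code_n > 0:   # errors are reported on the response object
            print(resp.error.message)
        else:
            print(resp.choice.message.content, end="")
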
@@ -21,7 +64,7 @@ class FangZhouWorker(ApiModelWorker):
         **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 16384)
+        kwargs.setdefault("context_len", 16384)  # TODO: context size differs per model
         super().__init__(**kwargs)
         config = self.get_config()
@@ -41,37 +84,24 @@ class FangZhouWorker(ApiModelWorker):
     def generate_stream_gate(self, params):
         super().generate_stream_gate(params)
-        from volcengine.maas import MaasService, MaasException, ChatRole
-
-        maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
-        maas.set_ak(self.api_key)
-        maas.set_sk(self.secret_key)
-
-        # document: "https://www.volcengine.com/docs/82379/1099475"
-        req = {
-            "model": {
-                "name": self.version,
-            },
-            "parameters": {
-                # These parameters are examples only; see the API documentation
-                # of the specific model for the parameters it actually supports.
-                # "max_new_tokens": 1000,
-                "temperature": params.get("temperature", TEMPERATURE),
-            },
-            "messages": [{"role": ChatRole.USER, "content": params["prompt"]}]
-        }
+        messages = self.prompt_to_messages(params["prompt"])
         text = ""
-        try:
-            resps = maas.stream_chat(req)
-            for resp in resps:
-                text += resp.choice.message.content
-                yield json.dumps({"error_code": 0, "text": text},
-                                 ensure_ascii=False).encode() + b"\0"
-        except MaasException as e:
-            print(e)
+
+        for resp in request_volc_api(
+            messages=messages,
+            model_name=self.model_names[0],
+            version=self.version,
+            temperature=params.get("temperature", TEMPERATURE),
+        ):
+            error = resp.error
+            if error.code_n > 0:
+                data = {"error_code": error.code_n, "text": error.message}
+            elif chunk := resp.choice.message.content:
+                text += chunk
+                data = {"error_code": 0, "text": text}
+            yield json.dumps(data, ensure_ascii=False).encode() + b"\0"

     def get_embeddings(self, params):
         # TODO: support embeddings
         print("embedding")
         print(params)
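
generate_stream_gate now emits cumulative text as UTF-8 JSON objects of the form {"error_code": ..., "text": ...}, each terminated by b"\0". A sketch of how a caller might decode that stream; iter_stream_gate is a hypothetical helper, not part of this commit:

    import json

    def iter_stream_gate(byte_chunks):
        # Each chunk is a JSON object terminated by b"\0"; empty parts are skipped.
        for raw in byte_chunks:
            for part in raw.split(b"\0"):
                if not part:
                    continue
                msg = json.loads(part)
                if msg["error_code"] != 0:
                    raise RuntimeError(f"worker error {msg['error_code']}: {msg['text']}")
                yield msg["text"]  # cumulative text so far, not a delta
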
@@ -84,8 +114,8 @@ if __name__ == "__main__":
     worker = FangZhouWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20006",
+        worker_addr="http://127.0.0.1:21005",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20006)
+    uvicorn.run(app, port=21005)

View File

@@ -93,8 +93,8 @@ if __name__ == "__main__":
     worker = MiniMaxWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20004",
+        worker_addr="http://127.0.0.1:21002",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20003)
+    uvicorn.run(app, port=21002)

View File

@@ -168,8 +168,8 @@ if __name__ == "__main__":
     worker = QianFanWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20006",
+        worker_addr="http://127.0.0.1:21004"
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20006)
+    uvicorn.run(app, port=21004)

View File

@@ -94,8 +94,8 @@ if __name__ == "__main__":
     worker = XingHuoWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20005",
+        worker_addr="http://127.0.0.1:21003",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20005)
+    uvicorn.run(app, port=21003)

View File

@@ -67,8 +67,8 @@ if __name__ == "__main__":
     worker = ChatGLMWorker(
         controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:20003",
+        worker_addr="http://127.0.0.1:21001",
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
-    uvicorn.run(app, port=20003)
+    uvicorn.run(app, port=21001)
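
Taken together, the __main__ changes above move the online API workers onto a contiguous port block next to the controller (20001):

    ChatGLMWorker   20003 -> 21001
    MiniMaxWorker   20004 -> 21002  (uvicorn previously listened on 20003)
    XingHuoWorker   20005 -> 21003
    QianFanWorker   20006 -> 21004
    FangZhouWorker  20006 -> 21005  (no longer clashes with QianFanWorker)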

View File

@@ -0,0 +1,22 @@
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from server.model_workers.fangzhou import request_volc_api
+from pprint import pprint
+import pytest
+
+
+@pytest.mark.parametrize("version", ["chatglm-6b-model"])
+def test_fangzhou(version):
+    messages = [{"role": "user", "content": "hello"}]
+    print("\n" + version + "\n")
+    i = 1
+    for x in request_volc_api(messages, version=version):
+        print(type(x))
+        pprint(x)
+        if chunk := x.choice.message.content:
+            print(chunk)
+        assert x.choice.message
+        i += 1
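
This test hits the live API, so it needs valid FangZhou credentials in the model config. Assuming the file sits two directories below the repo root (matching the root_path computation above; the exact path is illustrative), it can be run with streamed output visible:

    pytest -s tests/api/test_fangzhou.py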

View File

@@ -8,7 +8,7 @@ from pprint import pprint
 import pytest


-@pytest.mark.parametrize("version", MODEL_VERSIONS.keys())
+@pytest.mark.parametrize("version", list(MODEL_VERSIONS.keys())[:2])
 def test_qianfan(version):
     messages = [{"role": "user", "content": "你好"}]
     print("\n" + version + "\n")