Improve the online API: support completion and embedding, and simplify online API development (#1886)

* Improve the online API: support completion and embedding, and simplify online API development.

New features:
- Four online models (Zhipu AI, MiniMax, Qianfan, and Qwen) now support embeddings (without going through FastChat; dedicated API endpoints will be provided separately later).
- Online model workers auto-detect the incoming parameters: when a prompt is passed in a non-messages format, it is automatically converted to completion form so that the completion interface is supported.

For developers:
- Refactored ApiModelWorker:
  - All online API requests are wrapped in a do_chat method that receives an ApiChatParams object, simplifying access to request parameters and configuration, and handling the FastChat interface automatically.
  - Strengthened error handling for API requests so that more meaningful messages are returned.
- Rewrote qianfan-api using the qianfan SDK.
- Unified the test cases for all online models in one place, simplifying test writing.

* Delete requirements_langflow.txt
parent e74fe2d950
commit b4c68ddd05
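As a rough, illustrative sketch of the worker pattern this commit introduces (a params object handed to a provider-specific do_chat that yields {"error_code": int, "text": str} dicts, as described in the commit message and in the base.py hunks below), consider the following minimal Python example. EchoWorker, ChatParams, and fake_provider_stream are hypothetical stand-ins invented for illustration only; the real classes are ApiModelWorker and ApiChatParams in server/model_workers/base.py.

from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional


@dataclass
class ChatParams:
    # Simplified stand-in for ApiChatParams: chat messages plus sampling options.
    messages: List[Dict[str, str]]
    temperature: float = 0.7
    max_tokens: Optional[int] = None


def fake_provider_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
    # Hypothetical provider SDK call: streams the last user message back in growing chunks.
    text = messages[-1]["content"]
    for i in range(0, len(text), 4):
        yield text[: i + 4]


class EchoWorker:
    model_names = ["echo-api"]

    def do_chat(self, params: ChatParams) -> Iterator[Dict]:
        # Each yielded dict follows the {"error_code", "text"} contract the commit describes.
        try:
            for text in fake_provider_stream(params.messages):
                yield {"error_code": 0, "text": text}
        except Exception as e:
            yield {"error_code": 500, "text": f"{self.model_names[0]} request error: {e}"}


if __name__ == "__main__":
    for chunk in EchoWorker().do_chat(ChatParams(messages=[{"role": "user", "content": "hello"}])):
        print(chunk)

A concrete provider worker would replace fake_provider_stream with its own SDK or HTTP streaming call, which is what the azure, baichuan, qianfan, and other workers in the diff below do.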
@@ -26,10 +26,11 @@ pytest
# scikit-learn
# numexpr
# vllm==0.1.7; sys_platform == "linux"

# online api libs
zhipuai
dashscope>=1.10.0 # qwen
# qianfan
qianfan
# volcengine>=1.0.106 # fangzhou

# uncomment libs if you want to use corresponding vector store

@@ -26,7 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
):

@@ -7,7 +7,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List
from typing import List, Optional
from server.chat.utils import History
from server.utils import get_prompt_template

@@ -22,7 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):

@@ -4,7 +4,7 @@ from configs import LLM_MODEL, TEMPERATURE
from server.utils import wrap_done, get_OpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
from typing import AsyncIterable, Optional
import asyncio
from langchain.prompts.chat import PromptTemplate
from server.utils import get_prompt_template

@@ -15,7 +15,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
echo: bool = Body(False, description="除了输出之外,还回显输入"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),

@@ -31,7 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)

@@ -1,5 +1,5 @@
from fastapi.responses import StreamingResponse
from typing import List
from typing import List, Optional
import openai
from configs import LLM_MODEL, logger, log_verbose
from server.utils import get_model_worker_config, fschat_openai_api_address

@@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel):
messages: List[OpenAiMessage]
temperature: float = 0.7
n: int = 1
max_tokens: int = None
max_tokens: Optional[int] = None
stop: List[str] = []
stream: bool = False
presence_penalty: int = 0

@@ -128,7 +128,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
):

@@ -2,9 +2,8 @@ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
import threading
from configs import (EMBEDDING_MODEL, CHUNK_SIZE, CACHED_VS_NUM,
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
logger, log_verbose)
from server.utils import embedding_device, get_model_path
from contextlib import contextmanager

@@ -122,7 +121,9 @@ class EmbeddingsPool(CachePool):
with item.acquire(msg="初始化"):
self.atomic.release()
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(openai_api_key=get_model_path(model), chunk_size=CHUNK_SIZE)
embeddings = OpenAIEmbeddings(model_name=model, # TODO: 支持Azure
openai_api_key=get_model_path(model),
chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
if 'zh' in model:
# for chinese model

@@ -1,3 +1,4 @@
from .base import *
from .zhipu import ChatGLMWorker
from .minimax import MiniMaxWorker
from .xinghuo import XingHuoWorker

@@ -1,46 +1,12 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config, get_httpx_client
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import json
from typing import List, Dict
from configs import TEMPERATURE


def request_azure_api(
messages: List[Dict[str, str]],
resource_name: str = None,
api_key: str = None,
deployment_name: str = None,
api_version: str = "2023-07-01-preview",
temperature: float = TEMPERATURE,
max_tokens: int = None,
model_name: str = "azure-api",
):
config = get_model_worker_config(model_name)
data = dict(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
)
url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'api-key': api_key,
}
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
if not line.strip() or "[DONE]" in line:
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
yield resp

class AzureWorker(ApiModelWorker):
def __init__(
self,

@@ -54,37 +20,40 @@ class AzureWorker(ApiModelWorker):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8192)
super().__init__(**kwargs)

config = self.get_config()

self.resource_name = config.get("resource_name")
self.api_key = config.get("api_key")
self.api_version = config.get("api_version")
self.version = version
self.deployment_name = config.get("deployment_name")

def generate_stream_gate(self, params):
super().generate_stream_gate(params)
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
data = dict(
messages=params.messages,
temperature=params.temperature,
max_tokens=params.max_tokens,
stream=True,
)
url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
.format(params.resource_name, params.deployment_name, params.api_version))
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'api-key': params.api_key,
}

messages = self.prompt_to_messages(params["prompt"])
text = ""
for resp in request_azure_api(messages=messages,
resource_name=self.resource_name,
api_key=self.api_key,
api_version=self.api_version,
deployment_name=self.deployment_name,
temperature=params.get("temperature"),
max_tokens=params.get("max_tokens")):
if choices := resp["choices"]:
if chunk := choices[0].get("delta", {}).get("content"):
text += chunk
yield json.dumps(
{
"error_code": 0,
"text": text
},
ensure_ascii=False
).encode() + b"\0"
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
if not line.strip() or "[DONE]" in line:
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
if choices := resp["choices"]:
if chunk := choices[0].get("delta", {}).get("content"):
text += chunk
yield {
"error_code": 0,
"text": text
}

def get_embeddings(self, params):
# TODO: 支持embeddings

@@ -95,7 +64,7 @@ class AzureWorker(ApiModelWorker):
# TODO: 确认模板是否需要修改
return conv.Conversation(
name=self.model_names[0],
system_message="",
system_message="You are a helpful, respectful and honest assistant.",
messages=[],
roles=["user", "assistant"],
sep="\n### ",

@ -1,18 +1,14 @@
|
|||
# import os
|
||||
# import sys
|
||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
from server.model_workers.base import ApiModelWorker
|
||||
from server.utils import get_model_worker_config, get_httpx_client
|
||||
from server.model_workers.base import *
|
||||
from server.utils import get_httpx_client
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Literal, Dict
|
||||
from configs import TEMPERATURE
|
||||
|
||||
|
||||
def calculate_md5(input_string):
|
||||
|
|
@ -22,56 +18,12 @@ def calculate_md5(input_string):
|
|||
return encrypted
|
||||
|
||||
|
||||
def request_baichuan_api(
|
||||
messages: List[Dict[str, str]],
|
||||
api_key: str = None,
|
||||
secret_key: str = None,
|
||||
version: str = "Baichuan2-53B",
|
||||
temperature: float = TEMPERATURE,
|
||||
model_name: str = "baichuan-api",
|
||||
):
|
||||
config = get_model_worker_config(model_name)
|
||||
api_key = api_key or config.get("api_key")
|
||||
secret_key = secret_key or config.get("secret_key")
|
||||
version = version or config.get("version")
|
||||
|
||||
url = "https://api.baichuan-ai.com/v1/stream/chat"
|
||||
data = {
|
||||
"model": version,
|
||||
"messages": messages,
|
||||
"parameters": {"temperature": temperature}
|
||||
}
|
||||
|
||||
json_data = json.dumps(data)
|
||||
time_stamp = int(time.time())
|
||||
signature = calculate_md5(secret_key + json_data + str(time_stamp))
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + api_key,
|
||||
"X-BC-Request-Id": "your requestId",
|
||||
"X-BC-Timestamp": str(time_stamp),
|
||||
"X-BC-Signature": signature,
|
||||
"X-BC-Sign-Algo": "MD5",
|
||||
}
|
||||
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=data) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line.strip():
|
||||
continue
|
||||
resp = json.loads(line)
|
||||
yield resp
|
||||
|
||||
|
||||
class BaiChuanWorker(ApiModelWorker):
|
||||
BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat"
|
||||
SUPPORT_MODELS = ["Baichuan2-53B"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
model_names: List[str] = ["baichuan-api"],
|
||||
version: Literal["Baichuan2-53B"] = "Baichuan2-53B",
|
||||
**kwargs,
|
||||
|
|
@ -79,40 +31,48 @@ class BaiChuanWorker(ApiModelWorker):
|
|||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 32768)
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
config = self.get_config()
|
||||
self.version = config.get("version",version)
|
||||
self.api_key = config.get("api_key")
|
||||
self.secret_key = config.get("secret_key")
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
params.load_config(self.model_names[0])
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
super().generate_stream_gate(params)
|
||||
url = "https://api.baichuan-ai.com/v1/stream/chat"
|
||||
data = {
|
||||
"model": params.version,
|
||||
"messages": params.messages,
|
||||
"parameters": {"temperature": params.temperature}
|
||||
}
|
||||
|
||||
messages = self.prompt_to_messages(params["prompt"])
|
||||
json_data = json.dumps(data)
|
||||
time_stamp = int(time.time())
|
||||
signature = calculate_md5(params.secret_key + json_data + str(time_stamp))
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + params.api_key,
|
||||
"X-BC-Request-Id": "your requestId",
|
||||
"X-BC-Timestamp": str(time_stamp),
|
||||
"X-BC-Signature": signature,
|
||||
"X-BC-Sign-Algo": "MD5",
|
||||
}
|
||||
|
||||
text = ""
|
||||
for resp in request_baichuan_api(messages=messages,
|
||||
api_key=self.api_key,
|
||||
secret_key=self.secret_key,
|
||||
version=self.version,
|
||||
temperature=params.get("temperature")):
|
||||
if resp["code"] == 0:
|
||||
text += resp["data"]["messages"][-1]["content"]
|
||||
yield json.dumps(
|
||||
{
|
||||
"error_code": resp["code"],
|
||||
"text": text
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
else:
|
||||
yield json.dumps(
|
||||
{
|
||||
"error_code": resp["code"],
|
||||
"text": resp["msg"]
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=data) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line.strip():
|
||||
continue
|
||||
resp = json.loads(line)
|
||||
if resp["code"] == 0:
|
||||
text += resp["data"]["messages"][-1]["content"]
|
||||
yield {
|
||||
"error_code": resp["code"],
|
||||
"text": text
|
||||
}
|
||||
else:
|
||||
yield {
|
||||
"error_code": resp["code"],
|
||||
"text": resp["msg"]
|
||||
}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
|
|
|
|||
|
|
@ -1,36 +1,106 @@
|
|||
from fastchat.conversation import Conversation
|
||||
from configs.basic_config import LOG_PATH
|
||||
from configs import LOG_PATH, TEMPERATURE
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.base_model_worker import BaseModelWorker
|
||||
import uuid
|
||||
import json
|
||||
import sys
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, root_validator
|
||||
import fastchat
|
||||
import asyncio
|
||||
from typing import Dict, List
|
||||
from server.utils import get_model_worker_config
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
|
||||
|
||||
# 恢复被fastchat覆盖的标准输出
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
|
||||
class ApiModelOutMsg(BaseModel):
|
||||
error_code: int = 0
|
||||
text: str
|
||||
class ApiConfigParams(BaseModel):
|
||||
'''
|
||||
在线API配置参数,未提供的值会自动从model_config.ONLINE_LLM_MODEL中读取
|
||||
'''
|
||||
api_base_url: Optional[str] = None
|
||||
api_proxy: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
secret_key: Optional[str] = None
|
||||
group_id: Optional[str] = None # for minimax
|
||||
is_pro: bool = False # for minimax
|
||||
|
||||
APPID: Optional[str] = None # for xinghuo
|
||||
APISecret: Optional[str] = None # for xinghuo
|
||||
is_v2: bool = False # for xinghuo
|
||||
|
||||
worker_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_config(cls, v: Dict) -> Dict:
|
||||
if config := get_model_worker_config(v.get("worker_name")):
|
||||
for n in cls.__fields__:
|
||||
if n in config:
|
||||
v[n] = config[n]
|
||||
return v
|
||||
|
||||
def load_config(self, worker_name: str):
|
||||
self.worker_name = worker_name
|
||||
if config := get_model_worker_config(worker_name):
|
||||
for n in self.__fields__:
|
||||
if n in config:
|
||||
setattr(self, n, config[n])
|
||||
return self
|
||||
|
||||
|
||||
class ApiModelParams(ApiConfigParams):
|
||||
'''
|
||||
模型配置参数
|
||||
'''
|
||||
version: Optional[str] = None
|
||||
version_url: Optional[str] = None
|
||||
api_version: Optional[str] = None # for azure
|
||||
deployment_name: Optional[str] = None # for azure
|
||||
resource_name: Optional[str] = None # for azure
|
||||
|
||||
temperature: float = TEMPERATURE
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = 1.0
|
||||
|
||||
|
||||
class ApiChatParams(ApiModelParams):
|
||||
'''
|
||||
chat请求参数
|
||||
'''
|
||||
messages: List[Dict[str, str]]
|
||||
system_message: Optional[str] = None # for minimax
|
||||
role_meta: Dict = {} # for minimax
|
||||
|
||||
|
||||
class ApiCompletionParams(ApiModelParams):
|
||||
prompt: str
|
||||
|
||||
|
||||
class ApiEmbeddingsParams(ApiConfigParams):
|
||||
texts: List[str]
|
||||
embed_model: Optional[str] = None
|
||||
to_query: bool = False # for minimax
|
||||
|
||||
|
||||
class ApiModelWorker(BaseModelWorker):
|
||||
BASE_URL: str
|
||||
SUPPORT_MODELS: List
|
||||
DEFAULT_EMBED_MODEL: str = None # None means not support embedding
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_names: List[str],
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
context_len: int = 2048,
|
||||
no_register: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
|
||||
|
|
@ -40,9 +110,18 @@ class ApiModelWorker(BaseModelWorker):
|
|||
controller_addr=controller_addr,
|
||||
worker_addr=worker_addr,
|
||||
**kwargs)
|
||||
import sys
|
||||
# 恢复被fastchat覆盖的标准输出
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
self.context_len = context_len
|
||||
self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
|
||||
self.init_heart_beat()
|
||||
self.version = None
|
||||
|
||||
if not no_register and self.controller_addr:
|
||||
self.init_heart_beat()
|
||||
|
||||
|
||||
def count_token(self, params):
|
||||
# TODO:需要完善
|
||||
|
|
@ -50,33 +129,108 @@ class ApiModelWorker(BaseModelWorker):
|
|||
prompt = params["prompt"]
|
||||
return {"count": len(str(prompt)), "error_code": 0}
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
def generate_stream_gate(self, params: Dict):
|
||||
self.call_ct += 1
|
||||
|
||||
|
||||
try:
|
||||
prompt = params["prompt"]
|
||||
if self._is_chat(prompt):
|
||||
messages = self.prompt_to_messages(prompt)
|
||||
messages = self.validate_messages(messages)
|
||||
else: # 使用chat模仿续写功能,不支持历史消息
|
||||
messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]
|
||||
|
||||
p = ApiChatParams(
|
||||
messages=messages,
|
||||
temperature=params.get("temperature"),
|
||||
top_p=params.get("top_p"),
|
||||
max_tokens=params.get("max_new_tokens"),
|
||||
version=self.version,
|
||||
)
|
||||
for resp in self.do_chat(p):
|
||||
yield self._jsonify(resp)
|
||||
except Exception as e:
|
||||
yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误:{e}"})
|
||||
|
||||
def generate_gate(self, params):
|
||||
for x in self.generate_stream_gate(params):
|
||||
pass
|
||||
return json.loads(x[:-1].decode())
|
||||
try:
|
||||
for x in self.generate_stream_gate(params):
|
||||
...
|
||||
return json.loads(x[:-1].decode())
|
||||
except Exception as e:
|
||||
return {"error_code": 500, "text": str(e)}
|
||||
|
||||
|
||||
# 需要用户自定义的方法
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
'''
|
||||
执行Chat的方法,默认使用模块里面的chat函数。
|
||||
要求返回形式:{"error_code": int, "text": str}
|
||||
'''
|
||||
return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}
|
||||
|
||||
# def do_completion(self, p: ApiCompletionParams) -> Dict:
|
||||
# '''
|
||||
# 执行Completion的方法,默认使用模块里面的completion函数。
|
||||
# 要求返回形式:{"error_code": int, "text": str}
|
||||
# '''
|
||||
# return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"}
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
'''
|
||||
执行Embeddings的方法,默认使用模块里面的embed_documents函数。
|
||||
要求返回形式:{"code": int, "embeddings": List[List[float]], "msg": str}
|
||||
'''
|
||||
return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
print("embedding")
|
||||
# print(params)
|
||||
# fastchat对LLM做Embeddings限制很大,似乎只能使用openai的。
|
||||
# 在前端通过OpenAIEmbeddings发起的请求直接出错,无法请求过来。
|
||||
print("get_embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
raise NotImplementedError
|
||||
|
||||
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
|
||||
'''
|
||||
有些API对mesages有特殊格式,可以重写该函数替换默认的messages。
|
||||
之所以跟prompt_to_messages分开,是因为他们应用场景不同、参数不同
|
||||
'''
|
||||
return messages
|
||||
|
||||
|
||||
# help methods
|
||||
def get_config(self):
|
||||
from server.utils import get_model_worker_config
|
||||
return get_model_worker_config(self.model_names[0])
|
||||
@property
|
||||
def user_role(self):
|
||||
return self.conv.roles[0]
|
||||
|
||||
@property
|
||||
def ai_role(self):
|
||||
return self.conv.roles[1]
|
||||
|
||||
def _jsonify(self, data: Dict) -> str:
|
||||
'''
|
||||
将chat函数返回的结果按照fastchat openai-api-server的格式返回
|
||||
'''
|
||||
return json.dumps(data, ensure_ascii=False).encode() + b"\0"
|
||||
|
||||
def _is_chat(self, prompt: str) -> bool:
|
||||
'''
|
||||
检查prompt是否由chat messages拼接而来
|
||||
TODO: 存在误判的可能,也许从fastchat直接传入原始messages是更好的做法
|
||||
'''
|
||||
key = f"{self.conv.sep}{self.user_role}:"
|
||||
return key in prompt
|
||||
|
||||
def prompt_to_messages(self, prompt: str) -> List[Dict]:
|
||||
'''
|
||||
将prompt字符串拆分成messages.
|
||||
'''
|
||||
result = []
|
||||
user_role = self.conv.roles[0]
|
||||
ai_role = self.conv.roles[1]
|
||||
user_role = self.user_role
|
||||
ai_role = self.ai_role
|
||||
user_start = user_role + ":"
|
||||
ai_start = ai_role + ":"
|
||||
for msg in prompt.split(self.conv.sep)[1:-1]:
|
||||
|
|
@ -89,3 +243,7 @@ class ApiModelWorker(BaseModelWorker):
|
|||
else:
|
||||
raise RuntimeError(f"unknown role in msg: {msg}")
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def can_embedding(cls):
|
||||
return cls.DEFAULT_EMBED_MODEL is not None
|
||||
|
|
|
|||
|
|
@ -1,98 +1,58 @@
|
|||
from fastchat.conversation import Conversation
|
||||
from server.model_workers.base import ApiModelWorker
|
||||
from configs.model_config import TEMPERATURE
|
||||
from server.model_workers.base import *
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
from pprint import pprint
|
||||
from server.utils import get_model_worker_config
|
||||
from typing import List, Literal, Dict
|
||||
|
||||
|
||||
def request_volc_api(
|
||||
messages: List[Dict],
|
||||
model_name: str = "fangzhou-api",
|
||||
version: str = "chatglm-6b-model",
|
||||
temperature: float = TEMPERATURE,
|
||||
api_key: str = None,
|
||||
secret_key: str = None,
|
||||
):
|
||||
from volcengine.maas import MaasService, MaasException, ChatRole
|
||||
|
||||
maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
|
||||
config = get_model_worker_config(model_name)
|
||||
version = version or config.get("version")
|
||||
version_url = config.get("version_url")
|
||||
api_key = api_key or config.get("api_key")
|
||||
secret_key = secret_key or config.get("secret_key")
|
||||
|
||||
maas.set_ak(api_key)
|
||||
maas.set_sk(secret_key)
|
||||
|
||||
# document: "https://www.volcengine.com/docs/82379/1099475"
|
||||
req = {
|
||||
"model": {
|
||||
"name": version,
|
||||
},
|
||||
"parameters": {
|
||||
# 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明
|
||||
"max_new_tokens": 1000,
|
||||
"temperature": temperature,
|
||||
},
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
try:
|
||||
resps = maas.stream_chat(req)
|
||||
for resp in resps:
|
||||
yield resp
|
||||
except MaasException as e:
|
||||
print(e)
|
||||
|
||||
|
||||
class FangZhouWorker(ApiModelWorker):
|
||||
"""
|
||||
火山方舟
|
||||
"""
|
||||
SUPPORT_MODELS = ["chatglm-6b-model"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
|
||||
model_names: List[str] = ["fangzhou-api"],
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
**kwargs,
|
||||
self,
|
||||
*,
|
||||
model_names: List[str] = ["fangzhou-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
config = self.get_config()
|
||||
self.version = version
|
||||
self.api_key = config.get("api_key")
|
||||
self.secret_key = config.get("secret_key")
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
super().generate_stream_gate(params)
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
from volcengine.maas import MaasService
|
||||
|
||||
params.load_config(self.model_names[0])
|
||||
maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
|
||||
maas.set_ak(params.api_key)
|
||||
maas.set_sk(params.secret_key)
|
||||
|
||||
# document: "https://www.volcengine.com/docs/82379/1099475"
|
||||
req = {
|
||||
"model": {
|
||||
"name": params.version,
|
||||
},
|
||||
"parameters": {
|
||||
# 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明
|
||||
"max_new_tokens": params.max_tokens,
|
||||
"temperature": params.temperature,
|
||||
},
|
||||
"messages": params.messages,
|
||||
}
|
||||
|
||||
messages = self.prompt_to_messages(params["prompt"])
|
||||
text = ""
|
||||
|
||||
for resp in request_volc_api(messages=messages,
|
||||
model_name=self.model_names[0],
|
||||
version=self.version,
|
||||
temperature=params.get("temperature", TEMPERATURE),
|
||||
):
|
||||
for resp in maas.stream_chat(req):
|
||||
error = resp.error
|
||||
if error.code_n > 0:
|
||||
data = {"error_code": error.code_n, "text": error.message}
|
||||
yield {"error_code": error.code_n, "text": error.message}
|
||||
elif chunk := resp.choice.message.content:
|
||||
text += chunk
|
||||
data = {"error_code": 0, "text": text}
|
||||
yield json.dumps(data, ensure_ascii=False).encode() + b"\0"
|
||||
yield {"error_code": 0, "text": text}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
|
|
|
|||
|
|
@ -1,77 +1,108 @@
|
|||
from fastchat.conversation import Conversation
|
||||
from server.model_workers.base import ApiModelWorker
|
||||
from server.model_workers.base import *
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
from server.model_workers.base import ApiEmbeddingsParams
|
||||
from server.utils import get_httpx_client
|
||||
from pprint import pprint
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
class MiniMaxWorker(ApiModelWorker):
|
||||
BASE_URL = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
|
||||
DEFAULT_EMBED_MODEL = "embo-01"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_names: List[str] = ["minimax-api"],
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
version: str = "abab5.5-chat",
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 16384)
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def prompt_to_messages(self, prompt: str) -> List[Dict]:
|
||||
result = super().prompt_to_messages(prompt)
|
||||
messages = [{"sender_type": x["role"], "text": x["content"]} for x in result]
|
||||
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
|
||||
role_maps = {
|
||||
"user": self.user_role,
|
||||
"assistant": self.ai_role,
|
||||
"system": "system",
|
||||
}
|
||||
messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages]
|
||||
return messages
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
# 按照官网推荐,直接调用abab 5.5模型
|
||||
# TODO: 支持指定回复要求,支持指定用户名称、AI名称
|
||||
params.load_config(self.model_names[0])
|
||||
|
||||
super().generate_stream_gate(params)
|
||||
config = self.get_config()
|
||||
group_id = config.get("group_id")
|
||||
api_key = config.get("api_key")
|
||||
|
||||
pro = "_pro" if config.get("is_pro") else ""
|
||||
url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
|
||||
pro = "_pro" if params.is_pro else ""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Authorization": f"Bearer {params.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
messages = self.validate_messages(params.messages)
|
||||
data = {
|
||||
"model": "abab5.5-chat",
|
||||
"model": params.version,
|
||||
"stream": True,
|
||||
"tokens_to_generate": 1024, # TODO: 1024为官网默认值
|
||||
"mask_sensitive_info": True,
|
||||
"messages": self.prompt_to_messages(params["prompt"]),
|
||||
"temperature": params.get("temperature"),
|
||||
"top_p": params.get("top_p"),
|
||||
"bot_setting": [],
|
||||
"messages": messages,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"tokens_to_generate": params.max_tokens or 1024,
|
||||
# TODO: 以下参数为minimax特有,传入空值会出错。
|
||||
# "prompt": params.system_message or self.conv.system_message,
|
||||
# "bot_setting": [],
|
||||
# "role_meta": params.role_meta,
|
||||
}
|
||||
print("request data sent to minimax:")
|
||||
pprint(data)
|
||||
|
||||
with get_httpx_client() as client:
|
||||
response = client.stream("POST",
|
||||
self.BASE_URL.format(pro=pro, group_id=group_id),
|
||||
url.format(pro=pro, group_id=params.group_id),
|
||||
headers=headers,
|
||||
json=data)
|
||||
with response as r:
|
||||
text = ""
|
||||
for e in r.iter_text():
|
||||
if e.startswith("data: "): # 真是优秀的返回
|
||||
data = json.loads(e[6:])
|
||||
if not data.get("usage"):
|
||||
if choices := data.get("choices"):
|
||||
chunk = choices[0].get("delta", "").strip()
|
||||
if chunk:
|
||||
print(chunk)
|
||||
text += chunk
|
||||
yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
|
||||
|
||||
if not e.startswith("data: "): # 真是优秀的返回
|
||||
yield {"error_code": 500, "text": f"minimax返回错误的结果:{e}"}
|
||||
continue
|
||||
|
||||
data = json.loads(e[6:])
|
||||
if data.get("usage"):
|
||||
break
|
||||
|
||||
if choices := data.get("choices"):
|
||||
if chunk := choices[0].get("delta", ""):
|
||||
text += chunk
|
||||
yield {"error_code": 0, "text": text}
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
params.load_config(self.model_names[0])
|
||||
url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {params.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": params.embed_model or self.DEFAULT_EMBED_MODEL,
|
||||
"texts": params.texts,
|
||||
"type": "query" if params.to_query else "db",
|
||||
}
|
||||
|
||||
with get_httpx_client() as client:
|
||||
r = client.post(url, headers=headers, json=data).json()
|
||||
if embeddings := r.get("vectors"):
|
||||
return {"code": 200, "embeddings": embeddings}
|
||||
elif error := r.get("base_resp"):
|
||||
return {"code": error["status_code"], "msg": error["status_msg"]}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
|
|
@ -81,7 +112,7 @@ class MiniMaxWorker(ApiModelWorker):
|
|||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="",
|
||||
system_message="你是MiniMax自主研发的大型语言模型,回答问题简洁有条理。",
|
||||
messages=[],
|
||||
roles=["USER", "BOT"],
|
||||
sep="\n### ",
|
||||
|
|
|
|||
|
|
@ -1,155 +1,169 @@
|
|||
import sys
|
||||
from fastchat.conversation import Conversation
|
||||
from server.model_workers.base import ApiModelWorker
|
||||
from configs.model_config import TEMPERATURE
|
||||
from server.model_workers.base import *
|
||||
from fastchat import conversation as conv
|
||||
import json
|
||||
from cachetools import cached, TTLCache
|
||||
from server.utils import get_model_worker_config, get_httpx_client
|
||||
import sys
|
||||
from server.model_workers.base import ApiEmbeddingsParams
|
||||
from typing import List, Literal, Dict
|
||||
|
||||
MODEL_VERSIONS = {
|
||||
"ernie-bot": "completions",
|
||||
"ernie-bot-4": "completions_pro",
|
||||
"ernie-bot-turbo": "eb-instant",
|
||||
"bloomz-7b": "bloomz_7b1",
|
||||
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
|
||||
"llama2-7b-chat": "llama_2_7b",
|
||||
"llama2-13b-chat": "llama_2_13b",
|
||||
"llama2-70b-chat": "llama_2_70b",
|
||||
"qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
|
||||
"chatglm2-6b-32k": "chatglm2_6b_32k",
|
||||
"aquilachat-7b": "aquilachat_7b",
|
||||
# "linly-llama2-ch-7b": "", # 暂未发布
|
||||
# "linly-llama2-ch-13b": "", # 暂未发布
|
||||
# "chatglm2-6b": "", # 暂未发布
|
||||
# "chatglm2-6b-int4": "", # 暂未发布
|
||||
# "falcon-7b": "", # 暂未发布
|
||||
# "falcon-180b-chat": "", # 暂未发布
|
||||
# "falcon-40b": "", # 暂未发布
|
||||
# "rwkv4-world": "", # 暂未发布
|
||||
# "rwkv5-world": "", # 暂未发布
|
||||
# "rwkv4-pile-14b": "", # 暂未发布
|
||||
# "rwkv4-raven-14b": "", # 暂未发布
|
||||
# "open-llama-7b": "", # 暂未发布
|
||||
# "dolly-12b": "", # 暂未发布
|
||||
# "mpt-7b-instruct": "", # 暂未发布
|
||||
# "mpt-30b-instruct": "", # 暂未发布
|
||||
# "OA-Pythia-12B-SFT-4": "", # 暂未发布
|
||||
# "xverse-13b": "", # 暂未发布
|
||||
|
||||
# # 以下为企业测试,需要单独申请
|
||||
# "flan-ul2": "",
|
||||
# "Cerebras-GPT-6.7B": ""
|
||||
# "Pythia-6.9B": ""
|
||||
}
|
||||
# MODEL_VERSIONS = {
|
||||
# "ernie-bot": "completions",
|
||||
# "ernie-bot-turbo": "eb-instant",
|
||||
# "bloomz-7b": "bloomz_7b1",
|
||||
# "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
|
||||
# "llama2-7b-chat": "llama_2_7b",
|
||||
# "llama2-13b-chat": "llama_2_13b",
|
||||
# "llama2-70b-chat": "llama_2_70b",
|
||||
# "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
|
||||
# "chatglm2-6b-32k": "chatglm2_6b_32k",
|
||||
# "aquilachat-7b": "aquilachat_7b",
|
||||
# # "linly-llama2-ch-7b": "", # 暂未发布
|
||||
# # "linly-llama2-ch-13b": "", # 暂未发布
|
||||
# # "chatglm2-6b": "", # 暂未发布
|
||||
# # "chatglm2-6b-int4": "", # 暂未发布
|
||||
# # "falcon-7b": "", # 暂未发布
|
||||
# # "falcon-180b-chat": "", # 暂未发布
|
||||
# # "falcon-40b": "", # 暂未发布
|
||||
# # "rwkv4-world": "", # 暂未发布
|
||||
# # "rwkv5-world": "", # 暂未发布
|
||||
# # "rwkv4-pile-14b": "", # 暂未发布
|
||||
# # "rwkv4-raven-14b": "", # 暂未发布
|
||||
# # "open-llama-7b": "", # 暂未发布
|
||||
# # "dolly-12b": "", # 暂未发布
|
||||
# # "mpt-7b-instruct": "", # 暂未发布
|
||||
# # "mpt-30b-instruct": "", # 暂未发布
|
||||
# # "OA-Pythia-12B-SFT-4": "", # 暂未发布
|
||||
# # "xverse-13b": "", # 暂未发布
|
||||
|
||||
# # # 以下为企业测试,需要单独申请
|
||||
# # "flan-ul2": "",
|
||||
# # "Cerebras-GPT-6.7B": ""
|
||||
# # "Pythia-6.9B": ""
|
||||
# }
|
||||
|
||||
|
||||
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
||||
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
||||
"""
|
||||
使用 AK,SK 生成鉴权签名(Access Token)
|
||||
:return: access_token,或是None(如果错误)
|
||||
"""
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
|
||||
try:
|
||||
with get_httpx_client() as client:
|
||||
return client.get(url, params=params).json().get("access_token")
|
||||
except Exception as e:
|
||||
print(f"failed to get token from baidu: {e}")
|
||||
|
||||
|
||||
def request_qianfan_api(
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = TEMPERATURE,
|
||||
model_name: str = "qianfan-api",
|
||||
version: str = None,
|
||||
) -> Dict:
|
||||
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
|
||||
'/{model_version}?access_token={access_token}'
|
||||
config = get_model_worker_config(model_name)
|
||||
version = version or config.get("version")
|
||||
version_url = config.get("version_url")
|
||||
access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
|
||||
if not access_token:
|
||||
yield {
|
||||
"error_code": 403,
|
||||
"error_msg": f"failed to get access token. have you set the correct api_key and secret key?",
|
||||
}
|
||||
|
||||
url = BASE_URL.format(
|
||||
model_version=version_url or MODEL_VERSIONS[version],
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"stream": True
|
||||
}
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=payload) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line.strip():
|
||||
continue
|
||||
if line.startswith("data: "):
|
||||
line = line[6:]
|
||||
resp = json.loads(line)
|
||||
breakpoint()
|
||||
yield resp
|
||||
# @cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
||||
# def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
||||
# """
|
||||
# 使用 AK,SK 生成鉴权签名(Access Token)
|
||||
# :return: access_token,或是None(如果错误)
|
||||
# """
|
||||
# url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
# params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
|
||||
# try:
|
||||
# with get_httpx_client() as client:
|
||||
# return client.get(url, params=params).json().get("access_token")
|
||||
# except Exception as e:
|
||||
# print(f"failed to get token from baidu: {e}")
|
||||
|
||||
|
||||
class QianFanWorker(ApiModelWorker):
|
||||
"""
|
||||
百度千帆
|
||||
"""
|
||||
DEFAULT_EMBED_MODEL = "bge-large-zh"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
|
||||
model_names: List[str] = ["ernie-api"],
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
**kwargs,
|
||||
self,
|
||||
*,
|
||||
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
|
||||
model_names: List[str] = ["qianfan-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 16384)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
config = self.get_config()
|
||||
self.version = version
|
||||
self.api_key = config.get("api_key")
|
||||
self.secret_key = config.get("secret_key")
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
messages = self.prompt_to_messages(params["prompt"])
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
params.load_config(self.model_names[0])
|
||||
import qianfan
|
||||
|
||||
comp = qianfan.ChatCompletion(model=params.version,
|
||||
endpoint=params.version_url,
|
||||
ak=params.api_key,
|
||||
sk=params.secret_key,)
|
||||
text = ""
|
||||
for resp in request_qianfan_api(messages,
|
||||
temperature=params.get("temperature"),
|
||||
model_name=self.model_names[0]):
|
||||
if "result" in resp.keys():
|
||||
text += resp["result"]
|
||||
yield json.dumps({
|
||||
"error_code": 0,
|
||||
"text": text
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
for resp in comp.do(messages=params.messages,
|
||||
temperature=params.temperature,
|
||||
top_p=params.top_p,
|
||||
stream=True):
|
||||
if resp.code == 200:
|
||||
if chunk := resp.body.get("result"):
|
||||
text += chunk
|
||||
yield {
|
||||
"error_code": 0,
|
||||
"text": text
|
||||
}
|
||||
else:
|
||||
yield json.dumps({
|
||||
"error_code": resp["error_code"],
|
||||
"text": resp["error_msg"]
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
yield {
|
||||
"error_code": resp.code,
|
||||
"text": str(resp.body),
|
||||
}
|
||||
|
||||
# BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
|
||||
# '/{model_version}?access_token={access_token}'
|
||||
|
||||
# access_token = get_baidu_access_token(params.api_key, params.secret_key)
|
||||
# if not access_token:
|
||||
# yield {
|
||||
# "error_code": 403,
|
||||
# "text": f"failed to get access token. have you set the correct api_key and secret key?",
|
||||
# }
|
||||
|
||||
# url = BASE_URL.format(
|
||||
# model_version=params.version_url or MODEL_VERSIONS[params.version],
|
||||
# access_token=access_token,
|
||||
# )
|
||||
# payload = {
|
||||
# "messages": params.messages,
|
||||
# "temperature": params.temperature,
|
||||
# "stream": True
|
||||
# }
|
||||
# headers = {
|
||||
# 'Content-Type': 'application/json',
|
||||
# 'Accept': 'application/json',
|
||||
# }
|
||||
|
||||
# text = ""
|
||||
# with get_httpx_client() as client:
|
||||
# with client.stream("POST", url, headers=headers, json=payload) as response:
|
||||
# for line in response.iter_lines():
|
||||
# if not line.strip():
|
||||
# continue
|
||||
# if line.startswith("data: "):
|
||||
# line = line[6:]
|
||||
# resp = json.loads(line)
|
||||
|
||||
# if "result" in resp.keys():
|
||||
# text += resp["result"]
|
||||
# yield {
|
||||
# "error_code": 0,
|
||||
# "text": text
|
||||
# }
|
||||
# else:
|
||||
# yield {
|
||||
# "error_code": resp["error_code"],
|
||||
# "text": resp["error_msg"]
|
||||
# }
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
import qianfan
|
||||
params.load_config(self.model_names[0])
|
||||
|
||||
embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key)
|
||||
resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL)
|
||||
if resp.code == 200:
|
||||
embeddings = [x.embedding for x in resp.body.get("data", [])]
|
||||
return {"code": 200, "embeddings": embeddings}
|
||||
else:
|
||||
return {"code": resp.code, "msg": str(resp.body)}
|
||||
|
||||
|
||||
# TODO: qianfan支持续写模型
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
|
|
@ -159,7 +173,7 @@ class QianFanWorker(ApiModelWorker):
|
|||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="",
|
||||
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
|
||||
messages=[],
|
||||
roles=["user", "assistant"],
|
||||
sep="\n### ",
|
||||
|
|
|
|||
|
|
@ -7,94 +7,68 @@ from http import HTTPStatus
|
|||
from typing import List, Literal, Dict
|
||||
|
||||
from fastchat import conversation as conv
|
||||
|
||||
from server.model_workers.base import ApiModelWorker
|
||||
from server.utils import get_model_worker_config
|
||||
|
||||
|
||||
def request_qwen_api(
|
||||
messages: List[Dict[str, str]],
|
||||
api_key: str = None,
|
||||
version: str = "qwen-turbo",
|
||||
temperature: float = TEMPERATURE,
|
||||
model_name: str = "qwen-api",
|
||||
):
|
||||
import dashscope
|
||||
|
||||
config = get_model_worker_config(model_name)
|
||||
api_key = api_key or config.get("api_key")
|
||||
version = version or config.get("version")
|
||||
|
||||
gen = dashscope.Generation()
|
||||
responses = gen.call(
|
||||
model=version,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
result_format='message', # set the result is message format.
|
||||
stream=True,
|
||||
)
|
||||
|
||||
text = ""
|
||||
for resp in responses:
|
||||
if resp.status_code != HTTPStatus.OK:
|
||||
yield {
|
||||
"code": resp.status_code,
|
||||
"text": "api not response correctly",
|
||||
}
|
||||
|
||||
if resp["status_code"] == 200:
|
||||
if choices := resp["output"]["choices"]:
|
||||
yield {
|
||||
"code": 200,
|
||||
"text": choices[0]["message"]["content"],
|
||||
}
|
||||
else:
|
||||
yield {
|
||||
"code": resp["status_code"],
|
||||
"text": resp["message"],
|
||||
}
|
||||
from server.model_workers.base import *
|
||||
from server.model_workers.base import ApiEmbeddingsParams
|
||||
|
||||
|
||||
class QwenWorker(ApiModelWorker):
|
||||
DEFAULT_EMBED_MODEL = "text-embedding-v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
|
||||
model_names: List[str] = ["qwen-api"],
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
**kwargs,
|
||||
self,
|
||||
*,
|
||||
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
|
||||
model_names: List[str] = ["qwen-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 16384)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
config = self.get_config()
|
||||
self.api_key = config.get("api_key")
|
||||
self.version = version
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
messages = self.prompt_to_messages(params["prompt"])
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
import dashscope
|
||||
params.load_config(self.model_names[0])
|
||||
|
||||
for resp in request_qwen_api(messages=messages,
|
||||
api_key=self.api_key,
|
||||
version=self.version,
|
||||
temperature=params.get("temperature")):
|
||||
if resp["code"] == 200:
|
||||
yield json.dumps({
|
||||
"error_code": 0,
|
||||
"text": resp["text"]
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
gen = dashscope.Generation()
|
||||
responses = gen.call(
|
||||
model=params.version,
|
||||
temperature=params.temperature,
|
||||
api_key=params.api_key,
|
||||
messages=params.messages,
|
||||
result_format='message', # set the result is message format.
|
||||
stream=True,
|
||||
)
|
||||
|
||||
for resp in responses:
|
||||
if resp["status_code"] == 200:
|
||||
if choices := resp["output"]["choices"]:
|
||||
yield {
|
||||
"error_code": 0,
|
||||
"text": choices[0]["message"]["content"],
|
||||
}
|
||||
else:
|
||||
yield json.dumps({
|
||||
"error_code": resp["code"],
|
||||
"text": resp["text"]
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
yield {
|
||||
"error_code": resp["status_code"],
|
||||
"text": resp["message"],
|
||||
}
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
import dashscope
|
||||
params.load_config(self.model_names[0])
|
||||
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=params.embed_model or self.DEFAULT_EMBED_MODEL,
|
||||
input=params.texts, # 最大25行
|
||||
api_key=params.api_key,
|
||||
)
|
||||
if resp["status_code"] != 200:
|
||||
return {"code": resp["status_code"], "msg": resp.message}
|
||||
else:
|
||||
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
|
||||
return {"code": 200, "embeddings": embeddings}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
|
|
|
|||
|
|
@@ -1,12 +1,12 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from server.model_workers.base import *
from fastchat import conversation as conv
import sys
import json
from server.model_workers import SparkApi
import websockets
from server.utils import iter_over_async, asyncio
from typing import List
from typing import List, Dict


async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature):

@@ -31,46 +31,40 @@ class XingHuoWorker(ApiModelWorker):
self,
*,
model_names: List[str] = ["xinghuo-api"],
controller_addr: str,
worker_addr: str,
controller_addr: str = None,
worker_addr: str = None,
version: str = None,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8192)
super().__init__(**kwargs)
self.version = version

def generate_stream_gate(self, params):
def do_chat(self, params: ApiChatParams) -> Dict:
# TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接
params.load_config(self.model_names[0])

super().generate_stream_gate(params)
config = self.get_config()
appid = config.get("APPID")
api_secret = config.get("APISecret")
api_key = config.get("api_key")

if config.get("is_v2"):
if params.is_v2:
domain = "generalv2" # v2.0版本
Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
else:
domain = "general" # v1.5版本
Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址

question = self.prompt_to_messages(params["prompt"])
text = ""

try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()

for chunk in iter_over_async(
request(appid, api_key, api_secret, Spark_url, domain, question, params.get("temperature")),
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages, params.temperature),
loop=loop,
):
if chunk:
print(chunk)
text += chunk
yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
yield {"error_code": 0, "text": text}

def get_embeddings(self, params):
# TODO: 支持embeddings

@@ -81,7 +75,7 @@ class XingHuoWorker(ApiModelWorker):
# TODO: 确认模板是否需要修改
return conv.Conversation(
name=self.model_names[0],
system_message="",
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[],
roles=["user", "assistant"],
sep="\n### ",

@@ -1,22 +1,20 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from server.model_workers.base import *
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal
from typing import List, Dict, Iterator, Literal


class ChatGLMWorker(ApiModelWorker):
BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
DEFAULT_EMBED_MODEL = "text_embedding"

def __init__(
self,
*,
model_names: List[str] = ["zhipu-api"],
controller_addr: str = None,
worker_addr: str = None,
version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)

@@ -24,27 +22,45 @@ class ChatGLMWorker(ApiModelWorker):
super().__init__(**kwargs)
self.version = version

def generate_stream_gate(self, params):
def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
# TODO: 维护request_id
import zhipuai

super().generate_stream_gate(params)
zhipuai.api_key = self.get_config().get("api_key")
params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key

response = zhipuai.model_api.sse_invoke(
model=self.version,
prompt=[{"role": "user", "content": params["prompt"]}],
temperature=params.get("temperature"),
top_p=params.get("top_p"),
model=params.version,
prompt=params.messages,
temperature=params.temperature,
top_p=params.top_p,
incremental=False,
)
for e in response.events():
if e.event == "add":
yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
# TODO: 更健壮的消息处理
# elif e.event == "finish":
# ...

yield {"error_code": 0, "text": e.data}
elif e.event in ["error", "interrupted"]:
yield {"error_code": 500, "text": str(e)}

def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import zhipuai

params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key

embeddings = []
try:
for t in params.texts:
response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
if response["code"] == 200:
embeddings.append(response["data"]["embedding"])
else:
return response # dict with code & msg
except Exception as e:
return {"code": 500, "msg": f"对文本向量化时出错:{e}"}

return {"code": 200, "embeddings": embeddings}

def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")

@@ -56,7 +72,7 @@ class ChatGLMWorker(ApiModelWorker):
name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[],
roles=["Human", "Assistant"],
roles=["Human", "Assistant", "System"],
sep="\n###",
stop_str="###",
)

@@ -716,3 +716,15 @@ def get_server_configs() -> Dict:
}

return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}


def list_online_embed_models() -> List[str]:
from server import model_workers

ret = []
for k, v in list_config_llm_models()["online"].items():
provider = v.get("provider")
worker_class = getattr(model_workers, provider, None)
if worker_class is not None and worker_class.can_embedding():
ret.append(k)
return ret

@@ -1,26 +0,0 @@
import sys
from pathlib import Path

root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.azure import request_azure_api
import pytest
from configs import TEMPERATURE, ONLINE_LLM_MODEL
@pytest.mark.parametrize("version", ["azure-api"])
def test_azure(version):
messages = [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "介绍一下自己"}]
for resp in request_azure_api(
messages=messages,
api_base_url=ONLINE_LLM_MODEL["azure-api"]["api_base_url"],
api_key=ONLINE_LLM_MODEL["azure-api"]["api_key"],
deployment_name=ONLINE_LLM_MODEL["azure-api"]["deployment_name"],
api_version=ONLINE_LLM_MODEL["azure-api"]["api_version"],
temperature=TEMPERATURE,
max_tokens=1024,
model_name="azure-api"
):
assert resp["code"] == 200


if __name__ == "__main__":
test_azure("azure-api")

@@ -1,16 +0,0 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))

from server.model_workers.baichuan import request_baichuan_api
from pprint import pprint


def test_baichuan():
messages = [{"role": "user", "content": "hello"}]

for x in request_baichuan_api(messages):
print(type(x))
pprint(x)
assert x["code"] == 0

@@ -1,22 +0,0 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))

from server.model_workers.fangzhou import request_volc_api
from pprint import pprint
import pytest


@pytest.mark.parametrize("version", ["chatglm-6b-model"])
def test_qianfan(version):
messages = [{"role": "user", "content": "hello"}]
print("\n" + version + "\n")
i = 1
for x in request_volc_api(messages, version=version):
print(type(x))
pprint(x)
if chunk := x.choice.message.content:
print(chunk)
assert x.choice.message
i += 1

@@ -1,20 +0,0 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))

from server.model_workers.qianfan import request_qianfan_api, MODEL_VERSIONS
from pprint import pprint
import pytest


@pytest.mark.parametrize("version", list(MODEL_VERSIONS.keys())[:2])
def test_qianfan(version):
messages = [{"role": "user", "content": "你好"}]
print("\n" + version + "\n")
i = 1
for x in request_qianfan_api(messages, version=version):
pprint(x)
assert isinstance(x, dict)
assert "error_code" not in x
i += 1

@@ -1,19 +0,0 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))

from server.model_workers.qwen import request_qwen_api
from pprint import pprint
import pytest


@pytest.mark.parametrize("version", ["qwen-turbo"])
def test_qwen(version):
messages = [{"role": "user", "content": "hello"}]
print("\n" + version + "\n")

for x in request_qwen_api(messages, version=version):
print(type(x))
pprint(x)
assert x["code"] == 200

@@ -0,0 +1,68 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent
sys.path.append(str(root_path))

from configs import ONLINE_LLM_MODEL
from server.model_workers.base import *
from server.utils import get_model_worker_config, list_config_llm_models
from pprint import pprint
import pytest


workers = []
for x in list_config_llm_models()["online"]:
if x in ONLINE_LLM_MODEL and x not in workers:
workers.append(x)
print(f"all workers to test: {workers}")


@pytest.mark.parametrize("worker", workers)
def test_chat(worker):
params = ApiChatParams(
messages = [
{"role": "user", "content": "你是谁"},
],
)
print(f"\nchat with {worker} \n")

worker_class = get_model_worker_config(worker)["worker_class"]
for x in worker_class().do_chat(params):
pprint(x)
assert isinstance(x, dict)
assert x["error_code"] == 0


@pytest.mark.parametrize("worker", workers)
def test_embeddings(worker):
params = ApiEmbeddingsParams(
texts = [
"LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。",
"一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。",
]
)

worker_class = get_model_worker_config(worker)["worker_class"]
if worker_class.can_embedding():
print(f"\embeddings with {worker} \n")
resp = worker_class().do_embeddings(params)

pprint(resp, depth=2)
assert resp["code"] == 200
assert "embeddings" in resp
embeddings = resp["embeddings"]
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0]))


# @pytest.mark.parametrize("worker", workers)
# def test_completion(worker):
# params = ApiCompletionParams(prompt="五十六个民族")

# print(f"\completion with {worker} \n")

# worker_class = get_model_worker_config(worker)["worker_class"]
# resp = worker_class().do_completion(params)
# pprint(resp)