Improve online API support: add completion and embeddings, and simplify online API development (#1886)

* Improve online API support: add completion and embeddings, and simplify online API development

New features
- Embeddings are now supported for 4 online models: Zhipu AI, MiniMax, Qianfan and Qwen (not routed through FastChat; dedicated API endpoints will be provided separately later)
- Online models auto-detect the incoming parameters: when the prompt is not in messages format, it is automatically converted to completion form so the completion interface is supported (a simplified excerpt of the detection logic follows)
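
This conversion happens in the new ApiModelWorker.generate_stream_gate (see the server/model_workers/base.py diff below). A simplified excerpt of that logic, with bookkeeping and error handling omitted:

    prompt = params["prompt"]
    if self._is_chat(prompt):
        # the prompt was assembled from chat messages: split it back into messages
        messages = self.validate_messages(self.prompt_to_messages(prompt))
    else:
        # plain completion prompt: imitate continuation through the chat API
        messages = [{"role": self.user_role,
                     "content": f"please continue writing from here: {prompt}"}]
    p = ApiChatParams(messages=messages,
                      temperature=params.get("temperature"),
                      max_tokens=params.get("max_new_tokens"),
                      version=self.version)
    for resp in self.do_chat(p):       # every worker implements do_chat
        yield self._jsonify(resp)      # FastChat-compatible byte chunks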

For developers:
- Refactored ApiModelWorker:
  - All online API requests are wrapped in a do_chat method that receives an ApiChatParams object, simplifying access to request parameters and configuration items; the FastChat interface is handled automatically (as shown in the sketch after this list)
  - Stronger error handling for API requests, returning more meaningful messages
  - Rewrote qianfan-api on top of the qianfan SDK
  - Consolidated the test cases for all online models in one place, simplifying test writing
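
For illustration, a minimal sketch of what an online worker looks like after this refactor. The class below is hypothetical (DemoApiWorker, "demo-api" and "demo-embed" are placeholders, and the streaming/embedding calls are stand-ins); only the method names, config loading and return shapes follow the new base.py and the unified tests in this PR:

    from typing import Dict, Iterator
    from fastchat.conversation import Conversation
    from fastchat import conversation as conv
    from server.model_workers.base import (ApiModelWorker, ApiChatParams,
                                            ApiEmbeddingsParams)

    class DemoApiWorker(ApiModelWorker):       # hypothetical worker, for illustration only
        DEFAULT_EMBED_MODEL = "demo-embed"     # non-None means can_embedding() returns True

        def __init__(self, *, model_names=["demo-api"], controller_addr=None,
                     worker_addr=None, **kwargs):
            kwargs.update(model_names=model_names, controller_addr=controller_addr,
                          worker_addr=worker_addr)
            kwargs.setdefault("context_len", 4096)
            super().__init__(**kwargs)

        def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
            params.load_config(self.model_names[0])    # fill api_key etc. from ONLINE_LLM_MODEL
            text = ""
            for chunk in ["Hello", ", world"]:         # stand-in for the provider's streaming API
                text += chunk
                yield {"error_code": 0, "text": text}  # required return shape for do_chat

        def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
            params.load_config(self.model_names[0])
            vectors = [[0.0, 0.0, 0.0] for _ in params.texts]  # stand-in for the provider's API
            return {"code": 200, "embeddings": vectors}        # required return shape

        def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
            return conv.Conversation(name=self.model_names[0], system_message="",
                                     messages=[], roles=["user", "assistant"],
                                     sep="\n### ", stop_str="###")

With this shape, the unified tests in tests/test_online_api.py (added in this PR) can exercise any worker through do_chat / do_embeddings without per-provider test modules.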

* Delete requirements_langflow.txt
liunux4odoo 2023-10-26 22:44:48 +08:00 committed by GitHub
parent e74fe2d950
commit b4c68ddd05
25 changed files with 689 additions and 633 deletions

View File

@@ -26,10 +26,11 @@ pytest
 # scikit-learn
 # numexpr
 # vllm==0.1.7; sys_platform == "linux"
 # online api libs
 zhipuai
 dashscope>=1.10.0 # qwen
-# qianfan
+qianfan
 # volcengine>=1.0.106 # fangzhou
 # uncomment libs if you want to use corresponding vector store

View File

@@ -26,7 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
 stream: bool = Body(False, description="流式输出"),
 model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
 temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
+max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
 prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
 # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
 ):

View File

@@ -7,7 +7,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
-from typing import List
+from typing import List, Optional
 from server.chat.utils import History
 from server.utils import get_prompt_template
@@ -22,7 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
 stream: bool = Body(False, description="流式输出"),
 model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
 temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
+max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
 # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
 prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
 ):

View File

@@ -4,7 +4,7 @@ from configs import LLM_MODEL, TEMPERATURE
 from server.utils import wrap_done, get_OpenAI
 from langchain.chains import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
+from typing import AsyncIterable, Optional
 import asyncio
 from langchain.prompts.chat import PromptTemplate
 from server.utils import get_prompt_template
@@ -15,7 +15,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
 echo: bool = Body(False, description="除了输出之外,还回显输入"),
 model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
 temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-max_tokens: int = Body(1024, description="限制LLM生成Token数量默认None代表模型最大值"),
+max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量默认None代表模型最大值"),
 # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
 prompt_name: str = Body("default",
                         description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),

View File

@@ -31,7 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
 stream: bool = Body(False, description="流式输出"),
 model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
 temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
+max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
 prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
 ):
 kb = KBServiceFactory.get_service_by_name(knowledge_base_name)

View File

@@ -1,5 +1,5 @@
 from fastapi.responses import StreamingResponse
-from typing import List
+from typing import List, Optional
 import openai
 from configs import LLM_MODEL, logger, log_verbose
 from server.utils import get_model_worker_config, fschat_openai_api_address
@@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel):
 messages: List[OpenAiMessage]
 temperature: float = 0.7
 n: int = 1
-max_tokens: int = None
+max_tokens: Optional[int] = None
 stop: List[str] = []
 stream: bool = False
 presence_penalty: int = 0

View File

@@ -128,7 +128,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
 stream: bool = Body(False, description="流式输出"),
 model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
 temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
+max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
 prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
 split_result: bool = Body(False, description="是否对搜索结果进行拆分主要用于metaphor搜索引擎")
 ):

View File

@@ -2,9 +2,8 @@ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
 import threading
-from configs import (EMBEDDING_MODEL, CHUNK_SIZE, CACHED_VS_NUM,
+from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
                      logger, log_verbose)
 from server.utils import embedding_device, get_model_path
 from contextlib import contextmanager
@@ -122,7 +121,9 @@ class EmbeddingsPool(CachePool):
 with item.acquire(msg="初始化"):
 self.atomic.release()
 if model == "text-embedding-ada-002": # openai text-embedding-ada-002
-embeddings = OpenAIEmbeddings(openai_api_key=get_model_path(model), chunk_size=CHUNK_SIZE)
+embeddings = OpenAIEmbeddings(model_name=model, # TODO: 支持Azure
+                              openai_api_key=get_model_path(model),
+                              chunk_size=CHUNK_SIZE)
 elif 'bge-' in model:
 if 'zh' in model:
 # for chinese model

View File

@@ -1,3 +1,4 @@
+from .base import *
 from .zhipu import ChatGLMWorker
 from .minimax import MiniMaxWorker
 from .xinghuo import XingHuoWorker

View File

@ -1,46 +1,12 @@
import sys import sys
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker from server.model_workers.base import *
from server.utils import get_model_worker_config, get_httpx_client from server.utils import get_httpx_client
from fastchat import conversation as conv from fastchat import conversation as conv
import json import json
from typing import List, Dict from typing import List, Dict
from configs import TEMPERATURE
def request_azure_api(
messages: List[Dict[str, str]],
resource_name: str = None,
api_key: str = None,
deployment_name: str = None,
api_version: str = "2023-07-01-preview",
temperature: float = TEMPERATURE,
max_tokens: int = None,
model_name: str = "azure-api",
):
config = get_model_worker_config(model_name)
data = dict(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
)
url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'api-key': api_key,
}
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
if not line.strip() or "[DONE]" in line:
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
yield resp
class AzureWorker(ApiModelWorker): class AzureWorker(ApiModelWorker):
def __init__( def __init__(
self, self,
@ -54,37 +20,40 @@ class AzureWorker(ApiModelWorker):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8192) kwargs.setdefault("context_len", 8192)
super().__init__(**kwargs) super().__init__(**kwargs)
config = self.get_config()
self.resource_name = config.get("resource_name")
self.api_key = config.get("api_key")
self.api_version = config.get("api_version")
self.version = version self.version = version
self.deployment_name = config.get("deployment_name")
def generate_stream_gate(self, params): def do_chat(self, params: ApiChatParams) -> Dict:
super().generate_stream_gate(params) params.load_config(self.model_names[0])
data = dict(
messages=params.messages,
temperature=params.temperature,
max_tokens=params.max_tokens,
stream=True,
)
url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
.format(params.resource_name, params.deployment_name, params.api_version))
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'api-key': params.api_key,
}
messages = self.prompt_to_messages(params["prompt"])
text = "" text = ""
for resp in request_azure_api(messages=messages, with get_httpx_client() as client:
resource_name=self.resource_name, with client.stream("POST", url, headers=headers, json=data) as response:
api_key=self.api_key, for line in response.iter_lines():
api_version=self.api_version, if not line.strip() or "[DONE]" in line:
deployment_name=self.deployment_name, continue
temperature=params.get("temperature"), if line.startswith("data: "):
max_tokens=params.get("max_tokens")): line = line[6:]
if choices := resp["choices"]: resp = json.loads(line)
if chunk := choices[0].get("delta", {}).get("content"): if choices := resp["choices"]:
text += chunk if chunk := choices[0].get("delta", {}).get("content"):
yield json.dumps( text += chunk
{ yield {
"error_code": 0, "error_code": 0,
"text": text "text": text
}, }
ensure_ascii=False
).encode() + b"\0"
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings
@ -95,7 +64,7 @@ class AzureWorker(ApiModelWorker):
# TODO: 确认模板是否需要修改 # TODO: 确认模板是否需要修改
return conv.Conversation( return conv.Conversation(
name=self.model_names[0], name=self.model_names[0],
system_message="", system_message="You are a helpful, respectful and honest assistant.",
messages=[], messages=[],
roles=["user", "assistant"], roles=["user", "assistant"],
sep="\n### ", sep="\n### ",

View File

@ -1,18 +1,14 @@
# import os
# import sys
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import json import json
import time import time
import hashlib import hashlib
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker from server.model_workers.base import *
from server.utils import get_model_worker_config, get_httpx_client from server.utils import get_httpx_client
from fastchat import conversation as conv from fastchat import conversation as conv
import sys import sys
import json import json
from typing import List, Literal, Dict from typing import List, Literal, Dict
from configs import TEMPERATURE
def calculate_md5(input_string): def calculate_md5(input_string):
@ -22,56 +18,12 @@ def calculate_md5(input_string):
return encrypted return encrypted
def request_baichuan_api(
messages: List[Dict[str, str]],
api_key: str = None,
secret_key: str = None,
version: str = "Baichuan2-53B",
temperature: float = TEMPERATURE,
model_name: str = "baichuan-api",
):
config = get_model_worker_config(model_name)
api_key = api_key or config.get("api_key")
secret_key = secret_key or config.get("secret_key")
version = version or config.get("version")
url = "https://api.baichuan-ai.com/v1/stream/chat"
data = {
"model": version,
"messages": messages,
"parameters": {"temperature": temperature}
}
json_data = json.dumps(data)
time_stamp = int(time.time())
signature = calculate_md5(secret_key + json_data + str(time_stamp))
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + api_key,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(time_stamp),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",
}
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
if not line.strip():
continue
resp = json.loads(line)
yield resp
class BaiChuanWorker(ApiModelWorker): class BaiChuanWorker(ApiModelWorker):
BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat"
SUPPORT_MODELS = ["Baichuan2-53B"]
def __init__( def __init__(
self, self,
*, *,
controller_addr: str, controller_addr: str = None,
worker_addr: str, worker_addr: str = None,
model_names: List[str] = ["baichuan-api"], model_names: List[str] = ["baichuan-api"],
version: Literal["Baichuan2-53B"] = "Baichuan2-53B", version: Literal["Baichuan2-53B"] = "Baichuan2-53B",
**kwargs, **kwargs,
@ -79,40 +31,48 @@ class BaiChuanWorker(ApiModelWorker):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768) kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs) super().__init__(**kwargs)
self.version = version
config = self.get_config() def do_chat(self, params: ApiChatParams) -> Dict:
self.version = config.get("version",version) params.load_config(self.model_names[0])
self.api_key = config.get("api_key")
self.secret_key = config.get("secret_key")
def generate_stream_gate(self, params): url = "https://api.baichuan-ai.com/v1/stream/chat"
super().generate_stream_gate(params) data = {
"model": params.version,
"messages": params.messages,
"parameters": {"temperature": params.temperature}
}
messages = self.prompt_to_messages(params["prompt"]) json_data = json.dumps(data)
time_stamp = int(time.time())
signature = calculate_md5(params.secret_key + json_data + str(time_stamp))
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + params.api_key,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(time_stamp),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",
}
text = "" text = ""
for resp in request_baichuan_api(messages=messages, with get_httpx_client() as client:
api_key=self.api_key, with client.stream("POST", url, headers=headers, json=data) as response:
secret_key=self.secret_key, for line in response.iter_lines():
version=self.version, if not line.strip():
temperature=params.get("temperature")): continue
if resp["code"] == 0: resp = json.loads(line)
text += resp["data"]["messages"][-1]["content"] if resp["code"] == 0:
yield json.dumps( text += resp["data"]["messages"][-1]["content"]
{ yield {
"error_code": resp["code"], "error_code": resp["code"],
"text": text "text": text
}, }
ensure_ascii=False else:
).encode() + b"\0" yield {
else: "error_code": resp["code"],
yield json.dumps( "text": resp["msg"]
{ }
"error_code": resp["code"],
"text": resp["msg"]
},
ensure_ascii=False
).encode() + b"\0"
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings

View File

@ -1,36 +1,106 @@
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
from configs.basic_config import LOG_PATH from configs import LOG_PATH, TEMPERATURE
import fastchat.constants import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker from fastchat.serve.base_model_worker import BaseModelWorker
import uuid import uuid
import json import json
import sys import sys
from pydantic import BaseModel from pydantic import BaseModel, root_validator
import fastchat import fastchat
import asyncio import asyncio
from typing import Dict, List from server.utils import get_model_worker_config
from typing import Dict, List, Optional
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
# 恢复被fastchat覆盖的标准输出 # 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__ sys.stderr = sys.__stderr__
class ApiModelOutMsg(BaseModel): class ApiConfigParams(BaseModel):
error_code: int = 0 '''
text: str 在线API配置参数未提供的值会自动从model_config.ONLINE_LLM_MODEL中读取
'''
api_base_url: Optional[str] = None
api_proxy: Optional[str] = None
api_key: Optional[str] = None
secret_key: Optional[str] = None
group_id: Optional[str] = None # for minimax
is_pro: bool = False # for minimax
APPID: Optional[str] = None # for xinghuo
APISecret: Optional[str] = None # for xinghuo
is_v2: bool = False # for xinghuo
worker_name: Optional[str] = None
class Config:
extra = "allow"
@root_validator(pre=True)
def validate_config(cls, v: Dict) -> Dict:
if config := get_model_worker_config(v.get("worker_name")):
for n in cls.__fields__:
if n in config:
v[n] = config[n]
return v
def load_config(self, worker_name: str):
self.worker_name = worker_name
if config := get_model_worker_config(worker_name):
for n in self.__fields__:
if n in config:
setattr(self, n, config[n])
return self
class ApiModelParams(ApiConfigParams):
'''
模型配置参数
'''
version: Optional[str] = None
version_url: Optional[str] = None
api_version: Optional[str] = None # for azure
deployment_name: Optional[str] = None # for azure
resource_name: Optional[str] = None # for azure
temperature: float = TEMPERATURE
max_tokens: Optional[int] = None
top_p: Optional[float] = 1.0
class ApiChatParams(ApiModelParams):
'''
chat请求参数
'''
messages: List[Dict[str, str]]
system_message: Optional[str] = None # for minimax
role_meta: Dict = {} # for minimax
class ApiCompletionParams(ApiModelParams):
prompt: str
class ApiEmbeddingsParams(ApiConfigParams):
texts: List[str]
embed_model: Optional[str] = None
to_query: bool = False # for minimax
class ApiModelWorker(BaseModelWorker): class ApiModelWorker(BaseModelWorker):
BASE_URL: str DEFAULT_EMBED_MODEL: str = None # None means not support embedding
SUPPORT_MODELS: List
def __init__( def __init__(
self, self,
model_names: List[str], model_names: List[str],
controller_addr: str, controller_addr: str = None,
worker_addr: str, worker_addr: str = None,
context_len: int = 2048, context_len: int = 2048,
no_register: bool = False,
**kwargs, **kwargs,
): ):
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8]) kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
@ -40,9 +110,18 @@ class ApiModelWorker(BaseModelWorker):
controller_addr=controller_addr, controller_addr=controller_addr,
worker_addr=worker_addr, worker_addr=worker_addr,
**kwargs) **kwargs)
import sys
# 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
self.context_len = context_len self.context_len = context_len
self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency) self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
self.init_heart_beat() self.version = None
if not no_register and self.controller_addr:
self.init_heart_beat()
def count_token(self, params): def count_token(self, params):
# TODO需要完善 # TODO需要完善
@ -50,33 +129,108 @@ class ApiModelWorker(BaseModelWorker):
prompt = params["prompt"] prompt = params["prompt"]
return {"count": len(str(prompt)), "error_code": 0} return {"count": len(str(prompt)), "error_code": 0}
def generate_stream_gate(self, params): def generate_stream_gate(self, params: Dict):
self.call_ct += 1 self.call_ct += 1
try:
prompt = params["prompt"]
if self._is_chat(prompt):
messages = self.prompt_to_messages(prompt)
messages = self.validate_messages(messages)
else: # 使用chat模仿续写功能不支持历史消息
messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]
p = ApiChatParams(
messages=messages,
temperature=params.get("temperature"),
top_p=params.get("top_p"),
max_tokens=params.get("max_new_tokens"),
version=self.version,
)
for resp in self.do_chat(p):
yield self._jsonify(resp)
except Exception as e:
yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误{e}"})
def generate_gate(self, params): def generate_gate(self, params):
for x in self.generate_stream_gate(params): try:
pass for x in self.generate_stream_gate(params):
return json.loads(x[:-1].decode()) ...
return json.loads(x[:-1].decode())
except Exception as e:
return {"error_code": 500, "text": str(e)}
# 需要用户自定义的方法
def do_chat(self, params: ApiChatParams) -> Dict:
'''
执行Chat的方法默认使用模块里面的chat函数
要求返回形式{"error_code": int, "text": str}
'''
return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}
# def do_completion(self, p: ApiCompletionParams) -> Dict:
# '''
# 执行Completion的方法默认使用模块里面的completion函数。
# 要求返回形式:{"error_code": int, "text": str}
# '''
# return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"}
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
'''
执行Embeddings的方法默认使用模块里面的embed_documents函数
要求返回形式{"code": int, "embeddings": List[List[float]], "msg": str}
'''
return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}
def get_embeddings(self, params): def get_embeddings(self, params):
print("embedding") # fastchat对LLM做Embeddings限制很大似乎只能使用openai的。
# print(params) # 在前端通过OpenAIEmbeddings发起的请求直接出错无法请求过来。
print("get_embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
raise NotImplementedError raise NotImplementedError
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
'''
有些API对mesages有特殊格式可以重写该函数替换默认的messages
之所以跟prompt_to_messages分开是因为他们应用场景不同参数不同
'''
return messages
# help methods # help methods
def get_config(self): @property
from server.utils import get_model_worker_config def user_role(self):
return get_model_worker_config(self.model_names[0]) return self.conv.roles[0]
@property
def ai_role(self):
return self.conv.roles[1]
def _jsonify(self, data: Dict) -> str:
'''
将chat函数返回的结果按照fastchat openai-api-server的格式返回
'''
return json.dumps(data, ensure_ascii=False).encode() + b"\0"
def _is_chat(self, prompt: str) -> bool:
'''
检查prompt是否由chat messages拼接而来
TODO: 存在误判的可能也许从fastchat直接传入原始messages是更好的做法
'''
key = f"{self.conv.sep}{self.user_role}:"
return key in prompt
def prompt_to_messages(self, prompt: str) -> List[Dict]: def prompt_to_messages(self, prompt: str) -> List[Dict]:
''' '''
将prompt字符串拆分成messages. 将prompt字符串拆分成messages.
''' '''
result = [] result = []
user_role = self.conv.roles[0] user_role = self.user_role
ai_role = self.conv.roles[1] ai_role = self.ai_role
user_start = user_role + ":" user_start = user_role + ":"
ai_start = ai_role + ":" ai_start = ai_role + ":"
for msg in prompt.split(self.conv.sep)[1:-1]: for msg in prompt.split(self.conv.sep)[1:-1]:
@ -89,3 +243,7 @@ class ApiModelWorker(BaseModelWorker):
else: else:
raise RuntimeError(f"unknown role in msg: {msg}") raise RuntimeError(f"unknown role in msg: {msg}")
return result return result
@classmethod
def can_embedding(cls):
return cls.DEFAULT_EMBED_MODEL is not None

View File

@ -1,98 +1,58 @@
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker from server.model_workers.base import *
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv from fastchat import conversation as conv
import sys import sys
import json
from pprint import pprint
from server.utils import get_model_worker_config
from typing import List, Literal, Dict from typing import List, Literal, Dict
def request_volc_api(
messages: List[Dict],
model_name: str = "fangzhou-api",
version: str = "chatglm-6b-model",
temperature: float = TEMPERATURE,
api_key: str = None,
secret_key: str = None,
):
from volcengine.maas import MaasService, MaasException, ChatRole
maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
config = get_model_worker_config(model_name)
version = version or config.get("version")
version_url = config.get("version_url")
api_key = api_key or config.get("api_key")
secret_key = secret_key or config.get("secret_key")
maas.set_ak(api_key)
maas.set_sk(secret_key)
# document: "https://www.volcengine.com/docs/82379/1099475"
req = {
"model": {
"name": version,
},
"parameters": {
# 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明
"max_new_tokens": 1000,
"temperature": temperature,
},
"messages": messages,
}
try:
resps = maas.stream_chat(req)
for resp in resps:
yield resp
except MaasException as e:
print(e)
class FangZhouWorker(ApiModelWorker): class FangZhouWorker(ApiModelWorker):
""" """
火山方舟 火山方舟
""" """
SUPPORT_MODELS = ["chatglm-6b-model"]
def __init__( def __init__(
self, self,
*, *,
version: Literal["chatglm-6b-model"] = "chatglm-6b-model", model_names: List[str] = ["fangzhou-api"],
model_names: List[str] = ["fangzhou-api"], controller_addr: str = None,
controller_addr: str, worker_addr: str = None,
worker_addr: str, version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小 kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小
super().__init__(**kwargs) super().__init__(**kwargs)
config = self.get_config()
self.version = version self.version = version
self.api_key = config.get("api_key")
self.secret_key = config.get("secret_key")
def generate_stream_gate(self, params): def do_chat(self, params: ApiChatParams) -> Dict:
super().generate_stream_gate(params) from volcengine.maas import MaasService
params.load_config(self.model_names[0])
maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
maas.set_ak(params.api_key)
maas.set_sk(params.secret_key)
# document: "https://www.volcengine.com/docs/82379/1099475"
req = {
"model": {
"name": params.version,
},
"parameters": {
# 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明
"max_new_tokens": params.max_tokens,
"temperature": params.temperature,
},
"messages": params.messages,
}
messages = self.prompt_to_messages(params["prompt"])
text = "" text = ""
for resp in maas.stream_chat(req):
for resp in request_volc_api(messages=messages,
model_name=self.model_names[0],
version=self.version,
temperature=params.get("temperature", TEMPERATURE),
):
error = resp.error error = resp.error
if error.code_n > 0: if error.code_n > 0:
data = {"error_code": error.code_n, "text": error.message} yield {"error_code": error.code_n, "text": error.message}
elif chunk := resp.choice.message.content: elif chunk := resp.choice.message.content:
text += chunk text += chunk
data = {"error_code": 0, "text": text} yield {"error_code": 0, "text": text}
yield json.dumps(data, ensure_ascii=False).encode() + b"\0"
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings

View File

@ -1,77 +1,108 @@
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker from server.model_workers.base import *
from fastchat import conversation as conv from fastchat import conversation as conv
import sys import sys
import json import json
from server.model_workers.base import ApiEmbeddingsParams
from server.utils import get_httpx_client from server.utils import get_httpx_client
from pprint import pprint
from typing import List, Dict from typing import List, Dict
class MiniMaxWorker(ApiModelWorker): class MiniMaxWorker(ApiModelWorker):
BASE_URL = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}' DEFAULT_EMBED_MODEL = "embo-01"
def __init__( def __init__(
self, self,
*, *,
model_names: List[str] = ["minimax-api"], model_names: List[str] = ["minimax-api"],
controller_addr: str, controller_addr: str = None,
worker_addr: str, worker_addr: str = None,
version: str = "abab5.5-chat",
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384) kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs) super().__init__(**kwargs)
self.version = version
def prompt_to_messages(self, prompt: str) -> List[Dict]: def validate_messages(self, messages: List[Dict]) -> List[Dict]:
result = super().prompt_to_messages(prompt) role_maps = {
messages = [{"sender_type": x["role"], "text": x["content"]} for x in result] "user": self.user_role,
"assistant": self.ai_role,
"system": "system",
}
messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages]
return messages return messages
def generate_stream_gate(self, params): def do_chat(self, params: ApiChatParams) -> Dict:
# 按照官网推荐直接调用abab 5.5模型 # 按照官网推荐直接调用abab 5.5模型
# TODO: 支持指定回复要求支持指定用户名称、AI名称 # TODO: 支持指定回复要求支持指定用户名称、AI名称
params.load_config(self.model_names[0])
super().generate_stream_gate(params) url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
config = self.get_config() pro = "_pro" if params.is_pro else ""
group_id = config.get("group_id")
api_key = config.get("api_key")
pro = "_pro" if config.get("is_pro") else ""
headers = { headers = {
"Authorization": f"Bearer {api_key}", "Authorization": f"Bearer {params.api_key}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
messages = self.validate_messages(params.messages)
data = { data = {
"model": "abab5.5-chat", "model": params.version,
"stream": True, "stream": True,
"tokens_to_generate": 1024, # TODO: 1024为官网默认值
"mask_sensitive_info": True, "mask_sensitive_info": True,
"messages": self.prompt_to_messages(params["prompt"]), "messages": messages,
"temperature": params.get("temperature"), "temperature": params.temperature,
"top_p": params.get("top_p"), "top_p": params.top_p,
"bot_setting": [], "tokens_to_generate": params.max_tokens or 1024,
# TODO: 以下参数为minimax特有传入空值会出错。
# "prompt": params.system_message or self.conv.system_message,
# "bot_setting": [],
# "role_meta": params.role_meta,
} }
print("request data sent to minimax:")
pprint(data)
with get_httpx_client() as client: with get_httpx_client() as client:
response = client.stream("POST", response = client.stream("POST",
self.BASE_URL.format(pro=pro, group_id=group_id), url.format(pro=pro, group_id=params.group_id),
headers=headers, headers=headers,
json=data) json=data)
with response as r: with response as r:
text = "" text = ""
for e in r.iter_text(): for e in r.iter_text():
if e.startswith("data: "): # 真是优秀的返回 if not e.startswith("data: "): # 真是优秀的返回
data = json.loads(e[6:]) yield {"error_code": 500, "text": f"minimax返回错误的结果{e}"}
if not data.get("usage"): continue
if choices := data.get("choices"):
chunk = choices[0].get("delta", "").strip() data = json.loads(e[6:])
if chunk: if data.get("usage"):
print(chunk) break
text += chunk
yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0" if choices := data.get("choices"):
if chunk := choices[0].get("delta", ""):
text += chunk
yield {"error_code": 0, "text": text}
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
params.load_config(self.model_names[0])
url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}"
headers = {
"Authorization": f"Bearer {params.api_key}",
"Content-Type": "application/json",
}
data = {
"model": params.embed_model or self.DEFAULT_EMBED_MODEL,
"texts": params.texts,
"type": "query" if params.to_query else "db",
}
with get_httpx_client() as client:
r = client.post(url, headers=headers, json=data).json()
if embeddings := r.get("vectors"):
return {"code": 200, "embeddings": embeddings}
elif error := r.get("base_resp"):
return {"code": error["status_code"], "msg": error["status_msg"]}
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings
print("embedding") print("embedding")
@ -81,7 +112,7 @@ class MiniMaxWorker(ApiModelWorker):
# TODO: 确认模板是否需要修改 # TODO: 确认模板是否需要修改
return conv.Conversation( return conv.Conversation(
name=self.model_names[0], name=self.model_names[0],
system_message="", system_message="你是MiniMax自主研发的大型语言模型回答问题简洁有条理。",
messages=[], messages=[],
roles=["USER", "BOT"], roles=["USER", "BOT"],
sep="\n### ", sep="\n### ",

View File

@ -1,155 +1,169 @@
import sys import sys
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker from server.model_workers.base import *
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv from fastchat import conversation as conv
import json import sys
from cachetools import cached, TTLCache from server.model_workers.base import ApiEmbeddingsParams
from server.utils import get_model_worker_config, get_httpx_client
from typing import List, Literal, Dict from typing import List, Literal, Dict
MODEL_VERSIONS = {
"ernie-bot": "completions",
"ernie-bot-4": "completions_pro",
"ernie-bot-turbo": "eb-instant",
"bloomz-7b": "bloomz_7b1",
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
"llama2-7b-chat": "llama_2_7b",
"llama2-13b-chat": "llama_2_13b",
"llama2-70b-chat": "llama_2_70b",
"qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
"chatglm2-6b-32k": "chatglm2_6b_32k",
"aquilachat-7b": "aquilachat_7b",
# "linly-llama2-ch-7b": "", # 暂未发布
# "linly-llama2-ch-13b": "", # 暂未发布
# "chatglm2-6b": "", # 暂未发布
# "chatglm2-6b-int4": "", # 暂未发布
# "falcon-7b": "", # 暂未发布
# "falcon-180b-chat": "", # 暂未发布
# "falcon-40b": "", # 暂未发布
# "rwkv4-world": "", # 暂未发布
# "rwkv5-world": "", # 暂未发布
# "rwkv4-pile-14b": "", # 暂未发布
# "rwkv4-raven-14b": "", # 暂未发布
# "open-llama-7b": "", # 暂未发布
# "dolly-12b": "", # 暂未发布
# "mpt-7b-instruct": "", # 暂未发布
# "mpt-30b-instruct": "", # 暂未发布
# "OA-Pythia-12B-SFT-4": "", # 暂未发布
# "xverse-13b": "", # 暂未发布
# # 以下为企业测试,需要单独申请 # MODEL_VERSIONS = {
# "flan-ul2": "", # "ernie-bot": "completions",
# "Cerebras-GPT-6.7B": "" # "ernie-bot-turbo": "eb-instant",
# "Pythia-6.9B": "" # "bloomz-7b": "bloomz_7b1",
} # "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
# "llama2-7b-chat": "llama_2_7b",
# "llama2-13b-chat": "llama_2_13b",
# "llama2-70b-chat": "llama_2_70b",
# "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
# "chatglm2-6b-32k": "chatglm2_6b_32k",
# "aquilachat-7b": "aquilachat_7b",
# # "linly-llama2-ch-7b": "", # 暂未发布
# # "linly-llama2-ch-13b": "", # 暂未发布
# # "chatglm2-6b": "", # 暂未发布
# # "chatglm2-6b-int4": "", # 暂未发布
# # "falcon-7b": "", # 暂未发布
# # "falcon-180b-chat": "", # 暂未发布
# # "falcon-40b": "", # 暂未发布
# # "rwkv4-world": "", # 暂未发布
# # "rwkv5-world": "", # 暂未发布
# # "rwkv4-pile-14b": "", # 暂未发布
# # "rwkv4-raven-14b": "", # 暂未发布
# # "open-llama-7b": "", # 暂未发布
# # "dolly-12b": "", # 暂未发布
# # "mpt-7b-instruct": "", # 暂未发布
# # "mpt-30b-instruct": "", # 暂未发布
# # "OA-Pythia-12B-SFT-4": "", # 暂未发布
# # "xverse-13b": "", # 暂未发布
# # # 以下为企业测试,需要单独申请
# # "flan-ul2": "",
# # "Cerebras-GPT-6.7B": ""
# # "Pythia-6.9B": ""
# }
@cached(TTLCache(1, 1800)) # 经过测试缓存的token可以使用目前每30分钟刷新一次 # @cached(TTLCache(1, 1800)) # 经过测试缓存的token可以使用目前每30分钟刷新一次
def get_baidu_access_token(api_key: str, secret_key: str) -> str: # def get_baidu_access_token(api_key: str, secret_key: str) -> str:
""" # """
使用 AKSK 生成鉴权签名Access Token # 使用 AKSK 生成鉴权签名Access Token
:return: access_token或是None(如果错误) # :return: access_token或是None(如果错误)
""" # """
url = "https://aip.baidubce.com/oauth/2.0/token" # url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} # params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
try: # try:
with get_httpx_client() as client: # with get_httpx_client() as client:
return client.get(url, params=params).json().get("access_token") # return client.get(url, params=params).json().get("access_token")
except Exception as e: # except Exception as e:
print(f"failed to get token from baidu: {e}") # print(f"failed to get token from baidu: {e}")
def request_qianfan_api(
messages: List[Dict[str, str]],
temperature: float = TEMPERATURE,
model_name: str = "qianfan-api",
version: str = None,
) -> Dict:
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
'/{model_version}?access_token={access_token}'
config = get_model_worker_config(model_name)
version = version or config.get("version")
version_url = config.get("version_url")
access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
if not access_token:
yield {
"error_code": 403,
"error_msg": f"failed to get access token. have you set the correct api_key and secret key?",
}
url = BASE_URL.format(
model_version=version_url or MODEL_VERSIONS[version],
access_token=access_token,
)
payload = {
"messages": messages,
"temperature": temperature,
"stream": True
}
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
}
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=payload) as response:
for line in response.iter_lines():
if not line.strip():
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
breakpoint()
yield resp
class QianFanWorker(ApiModelWorker): class QianFanWorker(ApiModelWorker):
""" """
百度千帆 百度千帆
""" """
DEFAULT_EMBED_MODEL = "bge-large-zh"
def __init__( def __init__(
self, self,
*, *,
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot", version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
model_names: List[str] = ["ernie-api"], model_names: List[str] = ["qianfan-api"],
controller_addr: str, controller_addr: str = None,
worker_addr: str, worker_addr: str = None,
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384) kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs) super().__init__(**kwargs)
config = self.get_config()
self.version = version self.version = version
self.api_key = config.get("api_key")
self.secret_key = config.get("secret_key")
def generate_stream_gate(self, params): def do_chat(self, params: ApiChatParams) -> Dict:
messages = self.prompt_to_messages(params["prompt"]) params.load_config(self.model_names[0])
import qianfan
comp = qianfan.ChatCompletion(model=params.version,
endpoint=params.version_url,
ak=params.api_key,
sk=params.secret_key,)
text = "" text = ""
for resp in request_qianfan_api(messages, for resp in comp.do(messages=params.messages,
temperature=params.get("temperature"), temperature=params.temperature,
model_name=self.model_names[0]): top_p=params.top_p,
if "result" in resp.keys(): stream=True):
text += resp["result"] if resp.code == 200:
yield json.dumps({ if chunk := resp.body.get("result"):
"error_code": 0, text += chunk
"text": text yield {
}, "error_code": 0,
ensure_ascii=False "text": text
).encode() + b"\0" }
else: else:
yield json.dumps({ yield {
"error_code": resp["error_code"], "error_code": resp.code,
"text": resp["error_msg"] "text": str(resp.body),
}, }
ensure_ascii=False
).encode() + b"\0"
# BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
# '/{model_version}?access_token={access_token}'
# access_token = get_baidu_access_token(params.api_key, params.secret_key)
# if not access_token:
# yield {
# "error_code": 403,
# "text": f"failed to get access token. have you set the correct api_key and secret key?",
# }
# url = BASE_URL.format(
# model_version=params.version_url or MODEL_VERSIONS[params.version],
# access_token=access_token,
# )
# payload = {
# "messages": params.messages,
# "temperature": params.temperature,
# "stream": True
# }
# headers = {
# 'Content-Type': 'application/json',
# 'Accept': 'application/json',
# }
# text = ""
# with get_httpx_client() as client:
# with client.stream("POST", url, headers=headers, json=payload) as response:
# for line in response.iter_lines():
# if not line.strip():
# continue
# if line.startswith("data: "):
# line = line[6:]
# resp = json.loads(line)
# if "result" in resp.keys():
# text += resp["result"]
# yield {
# "error_code": 0,
# "text": text
# }
# else:
# yield {
# "error_code": resp["error_code"],
# "text": resp["error_msg"]
# }
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import qianfan
params.load_config(self.model_names[0])
embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key)
resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL)
if resp.code == 200:
embeddings = [x.embedding for x in resp.body.get("data", [])]
return {"code": 200, "embeddings": embeddings}
else:
return {"code": resp.code, "msg": str(resp.body)}
# TODO: qianfan支持续写模型
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings
print("embedding") print("embedding")
@ -159,7 +173,7 @@ class QianFanWorker(ApiModelWorker):
# TODO: 确认模板是否需要修改 # TODO: 确认模板是否需要修改
return conv.Conversation( return conv.Conversation(
name=self.model_names[0], name=self.model_names[0],
system_message="", system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[], messages=[],
roles=["user", "assistant"], roles=["user", "assistant"],
sep="\n### ", sep="\n### ",

View File

@ -7,94 +7,68 @@ from http import HTTPStatus
from typing import List, Literal, Dict from typing import List, Literal, Dict
from fastchat import conversation as conv from fastchat import conversation as conv
from server.model_workers.base import *
from server.model_workers.base import ApiModelWorker from server.model_workers.base import ApiEmbeddingsParams
from server.utils import get_model_worker_config
def request_qwen_api(
messages: List[Dict[str, str]],
api_key: str = None,
version: str = "qwen-turbo",
temperature: float = TEMPERATURE,
model_name: str = "qwen-api",
):
import dashscope
config = get_model_worker_config(model_name)
api_key = api_key or config.get("api_key")
version = version or config.get("version")
gen = dashscope.Generation()
responses = gen.call(
model=version,
temperature=temperature,
api_key=api_key,
messages=messages,
result_format='message', # set the result is message format.
stream=True,
)
text = ""
for resp in responses:
if resp.status_code != HTTPStatus.OK:
yield {
"code": resp.status_code,
"text": "api not response correctly",
}
if resp["status_code"] == 200:
if choices := resp["output"]["choices"]:
yield {
"code": 200,
"text": choices[0]["message"]["content"],
}
else:
yield {
"code": resp["status_code"],
"text": resp["message"],
}
class QwenWorker(ApiModelWorker): class QwenWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "text-embedding-v1"
def __init__( def __init__(
self, self,
*, *,
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo", version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
model_names: List[str] = ["qwen-api"], model_names: List[str] = ["qwen-api"],
controller_addr: str, controller_addr: str = None,
worker_addr: str, worker_addr: str = None,
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384) kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs) super().__init__(**kwargs)
config = self.get_config()
self.api_key = config.get("api_key")
self.version = version self.version = version
def generate_stream_gate(self, params): def do_chat(self, params: ApiChatParams) -> Dict:
messages = self.prompt_to_messages(params["prompt"]) import dashscope
params.load_config(self.model_names[0])
for resp in request_qwen_api(messages=messages, gen = dashscope.Generation()
api_key=self.api_key, responses = gen.call(
version=self.version, model=params.version,
temperature=params.get("temperature")): temperature=params.temperature,
if resp["code"] == 200: api_key=params.api_key,
yield json.dumps({ messages=params.messages,
"error_code": 0, result_format='message', # set the result is message format.
"text": resp["text"] stream=True,
}, )
ensure_ascii=False
).encode() + b"\0" for resp in responses:
if resp["status_code"] == 200:
if choices := resp["output"]["choices"]:
yield {
"error_code": 0,
"text": choices[0]["message"]["content"],
}
else: else:
yield json.dumps({ yield {
"error_code": resp["code"], "error_code": resp["status_code"],
"text": resp["text"] "text": resp["message"],
}, }
ensure_ascii=False
).encode() + b"\0" def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import dashscope
params.load_config(self.model_names[0])
resp = dashscope.TextEmbedding.call(
model=params.embed_model or self.DEFAULT_EMBED_MODEL,
input=params.texts, # 最大25行
api_key=params.api_key,
)
if resp["status_code"] != 200:
return {"code": resp["status_code"], "msg": resp.message}
else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
return {"code": 200, "embeddings": embeddings}
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings

View File

@ -1,12 +1,12 @@
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker from server.model_workers.base import *
from fastchat import conversation as conv from fastchat import conversation as conv
import sys import sys
import json import json
from server.model_workers import SparkApi from server.model_workers import SparkApi
import websockets import websockets
from server.utils import iter_over_async, asyncio from server.utils import iter_over_async, asyncio
from typing import List from typing import List, Dict
async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature): async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature):
@ -31,46 +31,40 @@ class XingHuoWorker(ApiModelWorker):
self, self,
*, *,
model_names: List[str] = ["xinghuo-api"], model_names: List[str] = ["xinghuo-api"],
controller_addr: str, controller_addr: str = None,
worker_addr: str, worker_addr: str = None,
version: str = None,
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8192) kwargs.setdefault("context_len", 8192)
super().__init__(**kwargs) super().__init__(**kwargs)
self.version = version
def generate_stream_gate(self, params): def do_chat(self, params: ApiChatParams) -> Dict:
# TODO: 当前每次对话都要重新连接websocket确认是否可以保持连接 # TODO: 当前每次对话都要重新连接websocket确认是否可以保持连接
params.load_config(self.model_names[0])
super().generate_stream_gate(params) if params.is_v2:
config = self.get_config()
appid = config.get("APPID")
api_secret = config.get("APISecret")
api_key = config.get("api_key")
if config.get("is_v2"):
domain = "generalv2" # v2.0版本 domain = "generalv2" # v2.0版本
Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址 Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
else: else:
domain = "general" # v1.5版本 domain = "general" # v1.5版本
Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址 Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
question = self.prompt_to_messages(params["prompt"])
text = "" text = ""
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
except: except:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
for chunk in iter_over_async( for chunk in iter_over_async(
request(appid, api_key, api_secret, Spark_url, domain, question, params.get("temperature")), request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages, params.temperature),
loop=loop, loop=loop,
): ):
if chunk: if chunk:
print(chunk)
text += chunk text += chunk
yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0" yield {"error_code": 0, "text": text}
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings
@ -81,7 +75,7 @@ class XingHuoWorker(ApiModelWorker):
# TODO: 确认模板是否需要修改 # TODO: 确认模板是否需要修改
return conv.Conversation( return conv.Conversation(
name=self.model_names[0], name=self.model_names[0],
system_message="", system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[], messages=[],
roles=["user", "assistant"], roles=["user", "assistant"],
sep="\n### ", sep="\n### ",

View File

@ -1,22 +1,20 @@
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker from server.model_workers.base import *
from fastchat import conversation as conv from fastchat import conversation as conv
import sys import sys
import json from typing import List, Dict, Iterator, Literal
from typing import List, Literal
class ChatGLMWorker(ApiModelWorker): class ChatGLMWorker(ApiModelWorker):
BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api" DEFAULT_EMBED_MODEL = "text_embedding"
SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
def __init__( def __init__(
self, self,
*, *,
model_names: List[str] = ["zhipu-api"], model_names: List[str] = ["zhipu-api"],
controller_addr: str = None,
worker_addr: str = None,
version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std", version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
controller_addr: str,
worker_addr: str,
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
@ -24,27 +22,45 @@ class ChatGLMWorker(ApiModelWorker):
super().__init__(**kwargs) super().__init__(**kwargs)
self.version = version self.version = version
def generate_stream_gate(self, params): def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
# TODO: 维护request_id # TODO: 维护request_id
import zhipuai import zhipuai
super().generate_stream_gate(params) params.load_config(self.model_names[0])
zhipuai.api_key = self.get_config().get("api_key") zhipuai.api_key = params.api_key
response = zhipuai.model_api.sse_invoke( response = zhipuai.model_api.sse_invoke(
model=self.version, model=params.version,
prompt=[{"role": "user", "content": params["prompt"]}], prompt=params.messages,
temperature=params.get("temperature"), temperature=params.temperature,
top_p=params.get("top_p"), top_p=params.top_p,
incremental=False, incremental=False,
) )
for e in response.events(): for e in response.events():
if e.event == "add": if e.event == "add":
yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0" yield {"error_code": 0, "text": e.data}
# TODO: 更健壮的消息处理 elif e.event in ["error", "interrupted"]:
# elif e.event == "finish": yield {"error_code": 500, "text": str(e)}
# ...
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import zhipuai
params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key
embeddings = []
try:
for t in params.texts:
response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
if response["code"] == 200:
embeddings.append(response["data"]["embedding"])
else:
return response # dict with code & msg
except Exception as e:
return {"code": 500, "msg": f"对文本向量化时出错:{e}"}
return {"code": 200, "embeddings": embeddings}
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings
print("embedding") print("embedding")
@ -56,7 +72,7 @@ class ChatGLMWorker(ApiModelWorker):
name=self.model_names[0], name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务", system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[], messages=[],
roles=["Human", "Assistant"], roles=["Human", "Assistant", "System"],
sep="\n###", sep="\n###",
stop_str="###", stop_str="###",
) )

View File

@@ -716,3 +716,15 @@ def get_server_configs() -> Dict:
    }
    return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
def list_online_embed_models() -> List[str]:
from server import model_workers
ret = []
for k, v in list_config_llm_models()["online"].items():
provider = v.get("provider")
worker_class = getattr(model_workers, provider, None)
if worker_class is not None and worker_class.can_embedding():
ret.append(k)
return ret

View File

@ -1,26 +0,0 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.azure import request_azure_api
import pytest
from configs import TEMPERATURE, ONLINE_LLM_MODEL
@pytest.mark.parametrize("version", ["azure-api"])
def test_azure(version):
messages = [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "介绍一下自己"}]
for resp in request_azure_api(
messages=messages,
api_base_url=ONLINE_LLM_MODEL["azure-api"]["api_base_url"],
api_key=ONLINE_LLM_MODEL["azure-api"]["api_key"],
deployment_name=ONLINE_LLM_MODEL["azure-api"]["deployment_name"],
api_version=ONLINE_LLM_MODEL["azure-api"]["api_version"],
temperature=TEMPERATURE,
max_tokens=1024,
model_name="azure-api"
):
assert resp["code"] == 200
if __name__ == "__main__":
test_azure("azure-api")

View File

@ -1,16 +0,0 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.baichuan import request_baichuan_api
from pprint import pprint
def test_baichuan():
messages = [{"role": "user", "content": "hello"}]
for x in request_baichuan_api(messages):
print(type(x))
pprint(x)
assert x["code"] == 0

View File

@ -1,22 +0,0 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.fangzhou import request_volc_api
from pprint import pprint
import pytest
@pytest.mark.parametrize("version", ["chatglm-6b-model"])
def test_qianfan(version):
messages = [{"role": "user", "content": "hello"}]
print("\n" + version + "\n")
i = 1
for x in request_volc_api(messages, version=version):
print(type(x))
pprint(x)
if chunk := x.choice.message.content:
print(chunk)
assert x.choice.message
i += 1

View File

@ -1,20 +0,0 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.qianfan import request_qianfan_api, MODEL_VERSIONS
from pprint import pprint
import pytest
@pytest.mark.parametrize("version", list(MODEL_VERSIONS.keys())[:2])
def test_qianfan(version):
messages = [{"role": "user", "content": "你好"}]
print("\n" + version + "\n")
i = 1
for x in request_qianfan_api(messages, version=version):
pprint(x)
assert isinstance(x, dict)
assert "error_code" not in x
i += 1

View File

@ -1,19 +0,0 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.qwen import request_qwen_api
from pprint import pprint
import pytest
@pytest.mark.parametrize("version", ["qwen-turbo"])
def test_qwen(version):
messages = [{"role": "user", "content": "hello"}]
print("\n" + version + "\n")
for x in request_qwen_api(messages, version=version):
print(type(x))
pprint(x)
assert x["code"] == 200

tests/test_online_api.py (new file, 68 lines)
View File

@ -0,0 +1,68 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent
sys.path.append(str(root_path))
from configs import ONLINE_LLM_MODEL
from server.model_workers.base import *
from server.utils import get_model_worker_config, list_config_llm_models
from pprint import pprint
import pytest
workers = []
for x in list_config_llm_models()["online"]:
if x in ONLINE_LLM_MODEL and x not in workers:
workers.append(x)
print(f"all workers to test: {workers}")
@pytest.mark.parametrize("worker", workers)
def test_chat(worker):
params = ApiChatParams(
messages = [
{"role": "user", "content": "你是谁"},
],
)
print(f"\nchat with {worker} \n")
worker_class = get_model_worker_config(worker)["worker_class"]
for x in worker_class().do_chat(params):
pprint(x)
assert isinstance(x, dict)
assert x["error_code"] == 0
@pytest.mark.parametrize("worker", workers)
def test_embeddings(worker):
params = ApiEmbeddingsParams(
texts = [
"LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。",
"一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。",
]
)
worker_class = get_model_worker_config(worker)["worker_class"]
if worker_class.can_embedding():
print(f"\embeddings with {worker} \n")
resp = worker_class().do_embeddings(params)
pprint(resp, depth=2)
assert resp["code"] == 200
assert "embeddings" in resp
embeddings = resp["embeddings"]
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0]))
# @pytest.mark.parametrize("worker", workers)
# def test_completion(worker):
# params = ApiCompletionParams(prompt="五十六个民族")
# print(f"\completion with {worker} \n")
# worker_class = get_model_worker_config(worker)["worker_class"]
# resp = worker_class().do_completion(params)
# pprint(resp)