diff --git a/requirements_lite.txt b/requirements_lite.txt index 127caa1..d453451 100644 --- a/requirements_lite.txt +++ b/requirements_lite.txt @@ -26,10 +26,11 @@ pytest # scikit-learn # numexpr # vllm==0.1.7; sys_platform == "linux" + # online api libs zhipuai dashscope>=1.10.0 # qwen -# qianfan +qianfan # volcengine>=1.0.106 # fangzhou # uncomment libs if you want to use corresponding vector store diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py index 6d92a10..2e96853 100644 --- a/server/chat/agent_chat.py +++ b/server/chat/agent_chat.py @@ -26,7 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), + max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), ): diff --git a/server/chat/chat.py b/server/chat/chat.py index 4402185..00d2fbb 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -7,7 +7,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler from typing import AsyncIterable import asyncio from langchain.prompts.chat import ChatPromptTemplate -from typing import List +from typing import List, Optional from server.chat.utils import History from server.utils import get_prompt_template @@ -22,7 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), + max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): diff --git a/server/chat/completion.py b/server/chat/completion.py index 587c9d4..a24858f 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -4,7 +4,7 @@ from configs import LLM_MODEL, TEMPERATURE from server.utils import wrap_done, get_OpenAI from langchain.chains import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler -from typing import AsyncIterable +from typing import AsyncIterable, Optional import asyncio from langchain.prompts.chat import PromptTemplate from server.utils import get_prompt_template @@ -15,7 +15,7 @@ async def completion(query: str = Body(..., description="用户输入", examples echo: bool = Body(False, description="除了输出之外,还回显输入"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"), + max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"), # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 
19ca871..b7fe85c 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -31,7 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), + max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py index 7efb0a8..4d6c58d 100644 --- a/server/chat/openai_chat.py +++ b/server/chat/openai_chat.py @@ -1,5 +1,5 @@ from fastapi.responses import StreamingResponse -from typing import List +from typing import List, Optional import openai from configs import LLM_MODEL, logger, log_verbose from server.utils import get_model_worker_config, fschat_openai_api_address @@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel): messages: List[OpenAiMessage] temperature: float = 0.7 n: int = 1 - max_tokens: int = None + max_tokens: Optional[int] = None stop: List[str] = [] stream: bool = False presence_penalty: int = 0 diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 5ae8298..3d2100e 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -128,7 +128,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入", stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), + max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)") ): diff --git a/server/knowledge_base/kb_cache/base.py b/server/knowledge_base/kb_cache/base.py index bb4942a..b99ca80 100644 --- a/server/knowledge_base/kb_cache/base.py +++ b/server/knowledge_base/kb_cache/base.py @@ -2,9 +2,8 @@ from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings import HuggingFaceBgeEmbeddings from langchain.embeddings.base import Embeddings -from langchain.schema import Document import threading -from configs import (EMBEDDING_MODEL, CHUNK_SIZE, CACHED_VS_NUM, +from configs import (EMBEDDING_MODEL, CHUNK_SIZE, logger, log_verbose) from server.utils import embedding_device, get_model_path from contextlib import contextmanager @@ -122,7 +121,9 @@ class EmbeddingsPool(CachePool): with item.acquire(msg="初始化"): self.atomic.release() if model == "text-embedding-ada-002": # openai text-embedding-ada-002 - embeddings = OpenAIEmbeddings(openai_api_key=get_model_path(model), chunk_size=CHUNK_SIZE) + embeddings = OpenAIEmbeddings(model_name=model, # TODO: 支持Azure + openai_api_key=get_model_path(model), + chunk_size=CHUNK_SIZE) elif 'bge-' in model: if 'zh' in model: # for chinese model diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py index 6285d67..4889f36 100644 --- a/server/model_workers/__init__.py 
+++ b/server/model_workers/__init__.py @@ -1,3 +1,4 @@ +from .base import * from .zhipu import ChatGLMWorker from .minimax import MiniMaxWorker from .xinghuo import XingHuoWorker diff --git a/server/model_workers/azure.py b/server/model_workers/azure.py index 4c438b2..3feff44 100644 --- a/server/model_workers/azure.py +++ b/server/model_workers/azure.py @@ -1,46 +1,12 @@ import sys from fastchat.conversation import Conversation -from server.model_workers.base import ApiModelWorker -from server.utils import get_model_worker_config, get_httpx_client +from server.model_workers.base import * +from server.utils import get_httpx_client from fastchat import conversation as conv import json from typing import List, Dict -from configs import TEMPERATURE -def request_azure_api( - messages: List[Dict[str, str]], - resource_name: str = None, - api_key: str = None, - deployment_name: str = None, - api_version: str = "2023-07-01-preview", - temperature: float = TEMPERATURE, - max_tokens: int = None, - model_name: str = "azure-api", -): - config = get_model_worker_config(model_name) - data = dict( - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - stream=True, - ) - url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}" - headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - 'api-key': api_key, - } - with get_httpx_client() as client: - with client.stream("POST", url, headers=headers, json=data) as response: - for line in response.iter_lines(): - if not line.strip() or "[DONE]" in line: - continue - if line.startswith("data: "): - line = line[6:] - resp = json.loads(line) - yield resp - class AzureWorker(ApiModelWorker): def __init__( self, @@ -54,37 +20,40 @@ class AzureWorker(ApiModelWorker): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.setdefault("context_len", 8192) super().__init__(**kwargs) - - config = self.get_config() - - self.resource_name = config.get("resource_name") - self.api_key = config.get("api_key") - self.api_version = config.get("api_version") self.version = version - self.deployment_name = config.get("deployment_name") - def generate_stream_gate(self, params): - super().generate_stream_gate(params) + def do_chat(self, params: ApiChatParams) -> Dict: + params.load_config(self.model_names[0]) + data = dict( + messages=params.messages, + temperature=params.temperature, + max_tokens=params.max_tokens, + stream=True, + ) + url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}" + .format(params.resource_name, params.deployment_name, params.api_version)) + headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + 'api-key': params.api_key, + } - messages = self.prompt_to_messages(params["prompt"]) text = "" - for resp in request_azure_api(messages=messages, - resource_name=self.resource_name, - api_key=self.api_key, - api_version=self.api_version, - deployment_name=self.deployment_name, - temperature=params.get("temperature"), - max_tokens=params.get("max_tokens")): - if choices := resp["choices"]: - if chunk := choices[0].get("delta", {}).get("content"): - text += chunk - yield json.dumps( - { - "error_code": 0, - "text": text - }, - ensure_ascii=False - ).encode() + b"\0" + with get_httpx_client() as client: + with client.stream("POST", url, headers=headers, json=data) as response: + for line in response.iter_lines(): + if not line.strip() or 
"[DONE]" in line: + continue + if line.startswith("data: "): + line = line[6:] + resp = json.loads(line) + if choices := resp["choices"]: + if chunk := choices[0].get("delta", {}).get("content"): + text += chunk + yield { + "error_code": 0, + "text": text + } def get_embeddings(self, params): # TODO: 支持embeddings @@ -95,7 +64,7 @@ class AzureWorker(ApiModelWorker): # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], - system_message="", + system_message="You are a helpful, respectful and honest assistant.", messages=[], roles=["user", "assistant"], sep="\n### ", diff --git a/server/model_workers/baichuan.py b/server/model_workers/baichuan.py index 38f21d6..b8dc426 100644 --- a/server/model_workers/baichuan.py +++ b/server/model_workers/baichuan.py @@ -1,18 +1,14 @@ -# import os -# import sys -# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) import json import time import hashlib from fastchat.conversation import Conversation -from server.model_workers.base import ApiModelWorker -from server.utils import get_model_worker_config, get_httpx_client +from server.model_workers.base import * +from server.utils import get_httpx_client from fastchat import conversation as conv import sys import json from typing import List, Literal, Dict -from configs import TEMPERATURE def calculate_md5(input_string): @@ -22,56 +18,12 @@ def calculate_md5(input_string): return encrypted -def request_baichuan_api( - messages: List[Dict[str, str]], - api_key: str = None, - secret_key: str = None, - version: str = "Baichuan2-53B", - temperature: float = TEMPERATURE, - model_name: str = "baichuan-api", -): - config = get_model_worker_config(model_name) - api_key = api_key or config.get("api_key") - secret_key = secret_key or config.get("secret_key") - version = version or config.get("version") - - url = "https://api.baichuan-ai.com/v1/stream/chat" - data = { - "model": version, - "messages": messages, - "parameters": {"temperature": temperature} - } - - json_data = json.dumps(data) - time_stamp = int(time.time()) - signature = calculate_md5(secret_key + json_data + str(time_stamp)) - headers = { - "Content-Type": "application/json", - "Authorization": "Bearer " + api_key, - "X-BC-Request-Id": "your requestId", - "X-BC-Timestamp": str(time_stamp), - "X-BC-Signature": signature, - "X-BC-Sign-Algo": "MD5", - } - - with get_httpx_client() as client: - with client.stream("POST", url, headers=headers, json=data) as response: - for line in response.iter_lines(): - if not line.strip(): - continue - resp = json.loads(line) - yield resp - - class BaiChuanWorker(ApiModelWorker): - BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat" - SUPPORT_MODELS = ["Baichuan2-53B"] - def __init__( self, *, - controller_addr: str, - worker_addr: str, + controller_addr: str = None, + worker_addr: str = None, model_names: List[str] = ["baichuan-api"], version: Literal["Baichuan2-53B"] = "Baichuan2-53B", **kwargs, @@ -79,40 +31,48 @@ class BaiChuanWorker(ApiModelWorker): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.setdefault("context_len", 32768) super().__init__(**kwargs) + self.version = version - config = self.get_config() - self.version = config.get("version",version) - self.api_key = config.get("api_key") - self.secret_key = config.get("secret_key") + def do_chat(self, params: ApiChatParams) -> Dict: + params.load_config(self.model_names[0]) - def generate_stream_gate(self, params): - super().generate_stream_gate(params) + url = 
"https://api.baichuan-ai.com/v1/stream/chat" + data = { + "model": params.version, + "messages": params.messages, + "parameters": {"temperature": params.temperature} + } - messages = self.prompt_to_messages(params["prompt"]) + json_data = json.dumps(data) + time_stamp = int(time.time()) + signature = calculate_md5(params.secret_key + json_data + str(time_stamp)) + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + params.api_key, + "X-BC-Request-Id": "your requestId", + "X-BC-Timestamp": str(time_stamp), + "X-BC-Signature": signature, + "X-BC-Sign-Algo": "MD5", + } text = "" - for resp in request_baichuan_api(messages=messages, - api_key=self.api_key, - secret_key=self.secret_key, - version=self.version, - temperature=params.get("temperature")): - if resp["code"] == 0: - text += resp["data"]["messages"][-1]["content"] - yield json.dumps( - { - "error_code": resp["code"], - "text": text - }, - ensure_ascii=False - ).encode() + b"\0" - else: - yield json.dumps( - { - "error_code": resp["code"], - "text": resp["msg"] - }, - ensure_ascii=False - ).encode() + b"\0" + with get_httpx_client() as client: + with client.stream("POST", url, headers=headers, json=data) as response: + for line in response.iter_lines(): + if not line.strip(): + continue + resp = json.loads(line) + if resp["code"] == 0: + text += resp["data"]["messages"][-1]["content"] + yield { + "error_code": resp["code"], + "text": text + } + else: + yield { + "error_code": resp["code"], + "text": resp["msg"] + } def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/model_workers/base.py b/server/model_workers/base.py index d4632bf..95ba7c6 100644 --- a/server/model_workers/base.py +++ b/server/model_workers/base.py @@ -1,36 +1,106 @@ from fastchat.conversation import Conversation -from configs.basic_config import LOG_PATH +from configs import LOG_PATH, TEMPERATURE import fastchat.constants fastchat.constants.LOGDIR = LOG_PATH from fastchat.serve.base_model_worker import BaseModelWorker import uuid import json import sys -from pydantic import BaseModel +from pydantic import BaseModel, root_validator import fastchat import asyncio -from typing import Dict, List +from server.utils import get_model_worker_config +from typing import Dict, List, Optional +__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"] + # 恢复被fastchat覆盖的标准输出 sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ -class ApiModelOutMsg(BaseModel): - error_code: int = 0 - text: str +class ApiConfigParams(BaseModel): + ''' + 在线API配置参数,未提供的值会自动从model_config.ONLINE_LLM_MODEL中读取 + ''' + api_base_url: Optional[str] = None + api_proxy: Optional[str] = None + api_key: Optional[str] = None + secret_key: Optional[str] = None + group_id: Optional[str] = None # for minimax + is_pro: bool = False # for minimax + + APPID: Optional[str] = None # for xinghuo + APISecret: Optional[str] = None # for xinghuo + is_v2: bool = False # for xinghuo + + worker_name: Optional[str] = None + + class Config: + extra = "allow" + + @root_validator(pre=True) + def validate_config(cls, v: Dict) -> Dict: + if config := get_model_worker_config(v.get("worker_name")): + for n in cls.__fields__: + if n in config: + v[n] = config[n] + return v + + def load_config(self, worker_name: str): + self.worker_name = worker_name + if config := get_model_worker_config(worker_name): + for n in self.__fields__: + if n in config: + setattr(self, n, config[n]) + return self + + +class ApiModelParams(ApiConfigParams): + ''' + 模型配置参数 
+ ''' + version: Optional[str] = None + version_url: Optional[str] = None + api_version: Optional[str] = None # for azure + deployment_name: Optional[str] = None # for azure + resource_name: Optional[str] = None # for azure + + temperature: float = TEMPERATURE + max_tokens: Optional[int] = None + top_p: Optional[float] = 1.0 + + +class ApiChatParams(ApiModelParams): + ''' + chat请求参数 + ''' + messages: List[Dict[str, str]] + system_message: Optional[str] = None # for minimax + role_meta: Dict = {} # for minimax + + +class ApiCompletionParams(ApiModelParams): + prompt: str + + +class ApiEmbeddingsParams(ApiConfigParams): + texts: List[str] + embed_model: Optional[str] = None + to_query: bool = False # for minimax + class ApiModelWorker(BaseModelWorker): - BASE_URL: str - SUPPORT_MODELS: List + DEFAULT_EMBED_MODEL: str = None # None means not support embedding def __init__( self, model_names: List[str], - controller_addr: str, - worker_addr: str, + controller_addr: str = None, + worker_addr: str = None, context_len: int = 2048, + no_register: bool = False, **kwargs, ): kwargs.setdefault("worker_id", uuid.uuid4().hex[:8]) @@ -40,9 +110,18 @@ class ApiModelWorker(BaseModelWorker): controller_addr=controller_addr, worker_addr=worker_addr, **kwargs) + import sys + # 恢复被fastchat覆盖的标准输出 + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + self.context_len = context_len self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency) - self.init_heart_beat() + self.version = None + + if not no_register and self.controller_addr: + self.init_heart_beat() + def count_token(self, params): # TODO:需要完善 @@ -50,33 +129,108 @@ class ApiModelWorker(BaseModelWorker): prompt = params["prompt"] return {"count": len(str(prompt)), "error_code": 0} - def generate_stream_gate(self, params): + def generate_stream_gate(self, params: Dict): self.call_ct += 1 - + + try: + prompt = params["prompt"] + if self._is_chat(prompt): + messages = self.prompt_to_messages(prompt) + messages = self.validate_messages(messages) + else: # 使用chat模仿续写功能,不支持历史消息 + messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}] + + p = ApiChatParams( + messages=messages, + temperature=params.get("temperature"), + top_p=params.get("top_p"), + max_tokens=params.get("max_new_tokens"), + version=self.version, + ) + for resp in self.do_chat(p): + yield self._jsonify(resp) + except Exception as e: + yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误:{e}"}) + def generate_gate(self, params): - for x in self.generate_stream_gate(params): - pass - return json.loads(x[:-1].decode()) + try: + for x in self.generate_stream_gate(params): + ... 
+ return json.loads(x[:-1].decode()) + except Exception as e: + return {"error_code": 500, "text": str(e)} + + + # 需要用户自定义的方法 + + def do_chat(self, params: ApiChatParams) -> Dict: + ''' + 执行Chat的方法,默认使用模块里面的chat函数。 + 要求返回形式:{"error_code": int, "text": str} + ''' + return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"} + + # def do_completion(self, p: ApiCompletionParams) -> Dict: + # ''' + # 执行Completion的方法,默认使用模块里面的completion函数。 + # 要求返回形式:{"error_code": int, "text": str} + # ''' + # return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"} + + def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: + ''' + 执行Embeddings的方法,默认使用模块里面的embed_documents函数。 + 要求返回形式:{"code": int, "embeddings": List[List[float]], "msg": str} + ''' + return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"} def get_embeddings(self, params): - print("embedding") - # print(params) + # fastchat对LLM做Embeddings限制很大,似乎只能使用openai的。 + # 在前端通过OpenAIEmbeddings发起的请求直接出错,无法请求过来。 + print("get_embedding") + print(params) def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: raise NotImplementedError + def validate_messages(self, messages: List[Dict]) -> List[Dict]: + ''' + 有些API对mesages有特殊格式,可以重写该函数替换默认的messages。 + 之所以跟prompt_to_messages分开,是因为他们应用场景不同、参数不同 + ''' + return messages + + # help methods - def get_config(self): - from server.utils import get_model_worker_config - return get_model_worker_config(self.model_names[0]) + @property + def user_role(self): + return self.conv.roles[0] + + @property + def ai_role(self): + return self.conv.roles[1] + + def _jsonify(self, data: Dict) -> str: + ''' + 将chat函数返回的结果按照fastchat openai-api-server的格式返回 + ''' + return json.dumps(data, ensure_ascii=False).encode() + b"\0" + + def _is_chat(self, prompt: str) -> bool: + ''' + 检查prompt是否由chat messages拼接而来 + TODO: 存在误判的可能,也许从fastchat直接传入原始messages是更好的做法 + ''' + key = f"{self.conv.sep}{self.user_role}:" + return key in prompt def prompt_to_messages(self, prompt: str) -> List[Dict]: ''' 将prompt字符串拆分成messages. 
''' result = [] - user_role = self.conv.roles[0] - ai_role = self.conv.roles[1] + user_role = self.user_role + ai_role = self.ai_role user_start = user_role + ":" ai_start = ai_role + ":" for msg in prompt.split(self.conv.sep)[1:-1]: @@ -89,3 +243,7 @@ class ApiModelWorker(BaseModelWorker): else: raise RuntimeError(f"unknown role in msg: {msg}") return result + + @classmethod + def can_embedding(cls): + return cls.DEFAULT_EMBED_MODEL is not None diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py index 8243b18..dc4d0f8 100644 --- a/server/model_workers/fangzhou.py +++ b/server/model_workers/fangzhou.py @@ -1,98 +1,58 @@ from fastchat.conversation import Conversation -from server.model_workers.base import ApiModelWorker -from configs.model_config import TEMPERATURE +from server.model_workers.base import * from fastchat import conversation as conv import sys -import json -from pprint import pprint -from server.utils import get_model_worker_config from typing import List, Literal, Dict -def request_volc_api( - messages: List[Dict], - model_name: str = "fangzhou-api", - version: str = "chatglm-6b-model", - temperature: float = TEMPERATURE, - api_key: str = None, - secret_key: str = None, -): - from volcengine.maas import MaasService, MaasException, ChatRole - - maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing') - config = get_model_worker_config(model_name) - version = version or config.get("version") - version_url = config.get("version_url") - api_key = api_key or config.get("api_key") - secret_key = secret_key or config.get("secret_key") - - maas.set_ak(api_key) - maas.set_sk(secret_key) - - # document: "https://www.volcengine.com/docs/82379/1099475" - req = { - "model": { - "name": version, - }, - "parameters": { - # 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明 - "max_new_tokens": 1000, - "temperature": temperature, - }, - "messages": messages, - } - - try: - resps = maas.stream_chat(req) - for resp in resps: - yield resp - except MaasException as e: - print(e) - - class FangZhouWorker(ApiModelWorker): """ 火山方舟 """ - SUPPORT_MODELS = ["chatglm-6b-model"] def __init__( - self, - *, - version: Literal["chatglm-6b-model"] = "chatglm-6b-model", - model_names: List[str] = ["fangzhou-api"], - controller_addr: str, - worker_addr: str, - **kwargs, + self, + *, + model_names: List[str] = ["fangzhou-api"], + controller_addr: str = None, + worker_addr: str = None, + version: Literal["chatglm-6b-model"] = "chatglm-6b-model", + **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小 - super().__init__(**kwargs) - - config = self.get_config() self.version = version - self.api_key = config.get("api_key") - self.secret_key = config.get("secret_key") - def generate_stream_gate(self, params): - super().generate_stream_gate(params) + def do_chat(self, params: ApiChatParams) -> Dict: + from volcengine.maas import MaasService + + params.load_config(self.model_names[0]) + maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing') + maas.set_ak(params.api_key) + maas.set_sk(params.secret_key) + + # document: "https://www.volcengine.com/docs/82379/1099475" + req = { + "model": { + "name": params.version, + }, + "parameters": { + # 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明 + "max_new_tokens": params.max_tokens, + "temperature": params.temperature, + }, + "messages": params.messages, + } - messages = self.prompt_to_messages(params["prompt"]) text = 
"" - - for resp in request_volc_api(messages=messages, - model_name=self.model_names[0], - version=self.version, - temperature=params.get("temperature", TEMPERATURE), - ): + for resp in maas.stream_chat(req): error = resp.error if error.code_n > 0: - data = {"error_code": error.code_n, "text": error.message} + yield {"error_code": error.code_n, "text": error.message} elif chunk := resp.choice.message.content: text += chunk - data = {"error_code": 0, "text": text} - yield json.dumps(data, ensure_ascii=False).encode() + b"\0" + yield {"error_code": 0, "text": text} def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py index 54eaf07..caffcba 100644 --- a/server/model_workers/minimax.py +++ b/server/model_workers/minimax.py @@ -1,77 +1,108 @@ from fastchat.conversation import Conversation -from server.model_workers.base import ApiModelWorker +from server.model_workers.base import * from fastchat import conversation as conv import sys import json +from server.model_workers.base import ApiEmbeddingsParams from server.utils import get_httpx_client -from pprint import pprint from typing import List, Dict class MiniMaxWorker(ApiModelWorker): - BASE_URL = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}' + DEFAULT_EMBED_MODEL = "embo-01" def __init__( self, *, model_names: List[str] = ["minimax-api"], - controller_addr: str, - worker_addr: str, + controller_addr: str = None, + worker_addr: str = None, + version: str = "abab5.5-chat", **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.setdefault("context_len", 16384) super().__init__(**kwargs) + self.version = version - def prompt_to_messages(self, prompt: str) -> List[Dict]: - result = super().prompt_to_messages(prompt) - messages = [{"sender_type": x["role"], "text": x["content"]} for x in result] + def validate_messages(self, messages: List[Dict]) -> List[Dict]: + role_maps = { + "user": self.user_role, + "assistant": self.ai_role, + "system": "system", + } + messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages] return messages - def generate_stream_gate(self, params): + def do_chat(self, params: ApiChatParams) -> Dict: # 按照官网推荐,直接调用abab 5.5模型 # TODO: 支持指定回复要求,支持指定用户名称、AI名称 + params.load_config(self.model_names[0]) - super().generate_stream_gate(params) - config = self.get_config() - group_id = config.get("group_id") - api_key = config.get("api_key") - - pro = "_pro" if config.get("is_pro") else "" + url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}' + pro = "_pro" if params.is_pro else "" headers = { - "Authorization": f"Bearer {api_key}", + "Authorization": f"Bearer {params.api_key}", "Content-Type": "application/json", } + messages = self.validate_messages(params.messages) data = { - "model": "abab5.5-chat", + "model": params.version, "stream": True, - "tokens_to_generate": 1024, # TODO: 1024为官网默认值 "mask_sensitive_info": True, - "messages": self.prompt_to_messages(params["prompt"]), - "temperature": params.get("temperature"), - "top_p": params.get("top_p"), - "bot_setting": [], + "messages": messages, + "temperature": params.temperature, + "top_p": params.top_p, + "tokens_to_generate": params.max_tokens or 1024, + # TODO: 以下参数为minimax特有,传入空值会出错。 + # "prompt": params.system_message or self.conv.system_message, + # "bot_setting": [], + # "role_meta": params.role_meta, } - print("request data sent to minimax:") - 
pprint(data) + with get_httpx_client() as client: response = client.stream("POST", - self.BASE_URL.format(pro=pro, group_id=group_id), + url.format(pro=pro, group_id=params.group_id), headers=headers, json=data) with response as r: text = "" for e in r.iter_text(): - if e.startswith("data: "): # 真是优秀的返回 - data = json.loads(e[6:]) - if not data.get("usage"): - if choices := data.get("choices"): - chunk = choices[0].get("delta", "").strip() - if chunk: - print(chunk) - text += chunk - yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0" - + if not e.startswith("data: "): # 真是优秀的返回 + yield {"error_code": 500, "text": f"minimax返回错误的结果:{e}"} + continue + + data = json.loads(e[6:]) + if data.get("usage"): + break + + if choices := data.get("choices"): + if chunk := choices[0].get("delta", ""): + text += chunk + yield {"error_code": 0, "text": text} + + def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: + params.load_config(self.model_names[0]) + url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}" + + headers = { + "Authorization": f"Bearer {params.api_key}", + "Content-Type": "application/json", + } + + data = { + "model": params.embed_model or self.DEFAULT_EMBED_MODEL, + "texts": params.texts, + "type": "query" if params.to_query else "db", + } + + with get_httpx_client() as client: + r = client.post(url, headers=headers, json=data).json() + if embeddings := r.get("vectors"): + return {"code": 200, "embeddings": embeddings} + elif error := r.get("base_resp"): + return {"code": error["status_code"], "msg": error["status_msg"]} + def get_embeddings(self, params): # TODO: 支持embeddings print("embedding") @@ -81,7 +112,7 @@ class MiniMaxWorker(ApiModelWorker): # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], - system_message="", + system_message="你是MiniMax自主研发的大型语言模型,回答问题简洁有条理。", messages=[], roles=["USER", "BOT"], sep="\n### ", diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index 07a64a4..2160db3 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -1,155 +1,169 @@ import sys from fastchat.conversation import Conversation -from server.model_workers.base import ApiModelWorker -from configs.model_config import TEMPERATURE +from server.model_workers.base import * from fastchat import conversation as conv -import json -from cachetools import cached, TTLCache -from server.utils import get_model_worker_config, get_httpx_client +import sys +from server.model_workers.base import ApiEmbeddingsParams from typing import List, Literal, Dict -MODEL_VERSIONS = { - "ernie-bot": "completions", - "ernie-bot-4": "completions_pro", - "ernie-bot-turbo": "eb-instant", - "bloomz-7b": "bloomz_7b1", - "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed", - "llama2-7b-chat": "llama_2_7b", - "llama2-13b-chat": "llama_2_13b", - "llama2-70b-chat": "llama_2_70b", - "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b", - "chatglm2-6b-32k": "chatglm2_6b_32k", - "aquilachat-7b": "aquilachat_7b", - # "linly-llama2-ch-7b": "", # 暂未发布 - # "linly-llama2-ch-13b": "", # 暂未发布 - # "chatglm2-6b": "", # 暂未发布 - # "chatglm2-6b-int4": "", # 暂未发布 - # "falcon-7b": "", # 暂未发布 - # "falcon-180b-chat": "", # 暂未发布 - # "falcon-40b": "", # 暂未发布 - # "rwkv4-world": "", # 暂未发布 - # "rwkv5-world": "", # 暂未发布 - # "rwkv4-pile-14b": "", # 暂未发布 - # "rwkv4-raven-14b": "", # 暂未发布 - # "open-llama-7b": "", # 暂未发布 - # "dolly-12b": "", # 暂未发布 - # "mpt-7b-instruct": "", # 暂未发布 - # "mpt-30b-instruct": "", # 暂未发布 - # 
"OA-Pythia-12B-SFT-4": "", # 暂未发布 - # "xverse-13b": "", # 暂未发布 - # # 以下为企业测试,需要单独申请 - # "flan-ul2": "", - # "Cerebras-GPT-6.7B": "" - # "Pythia-6.9B": "" -} +# MODEL_VERSIONS = { +# "ernie-bot": "completions", +# "ernie-bot-turbo": "eb-instant", +# "bloomz-7b": "bloomz_7b1", +# "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed", +# "llama2-7b-chat": "llama_2_7b", +# "llama2-13b-chat": "llama_2_13b", +# "llama2-70b-chat": "llama_2_70b", +# "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b", +# "chatglm2-6b-32k": "chatglm2_6b_32k", +# "aquilachat-7b": "aquilachat_7b", +# # "linly-llama2-ch-7b": "", # 暂未发布 +# # "linly-llama2-ch-13b": "", # 暂未发布 +# # "chatglm2-6b": "", # 暂未发布 +# # "chatglm2-6b-int4": "", # 暂未发布 +# # "falcon-7b": "", # 暂未发布 +# # "falcon-180b-chat": "", # 暂未发布 +# # "falcon-40b": "", # 暂未发布 +# # "rwkv4-world": "", # 暂未发布 +# # "rwkv5-world": "", # 暂未发布 +# # "rwkv4-pile-14b": "", # 暂未发布 +# # "rwkv4-raven-14b": "", # 暂未发布 +# # "open-llama-7b": "", # 暂未发布 +# # "dolly-12b": "", # 暂未发布 +# # "mpt-7b-instruct": "", # 暂未发布 +# # "mpt-30b-instruct": "", # 暂未发布 +# # "OA-Pythia-12B-SFT-4": "", # 暂未发布 +# # "xverse-13b": "", # 暂未发布 + +# # # 以下为企业测试,需要单独申请 +# # "flan-ul2": "", +# # "Cerebras-GPT-6.7B": "" +# # "Pythia-6.9B": "" +# } -@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次 -def get_baidu_access_token(api_key: str, secret_key: str) -> str: - """ - 使用 AK,SK 生成鉴权签名(Access Token) - :return: access_token,或是None(如果错误) - """ - url = "https://aip.baidubce.com/oauth/2.0/token" - params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} - try: - with get_httpx_client() as client: - return client.get(url, params=params).json().get("access_token") - except Exception as e: - print(f"failed to get token from baidu: {e}") - - -def request_qianfan_api( - messages: List[Dict[str, str]], - temperature: float = TEMPERATURE, - model_name: str = "qianfan-api", - version: str = None, -) -> Dict: - BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \ - '/{model_version}?access_token={access_token}' - config = get_model_worker_config(model_name) - version = version or config.get("version") - version_url = config.get("version_url") - access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key")) - if not access_token: - yield { - "error_code": 403, - "error_msg": f"failed to get access token. 
have you set the correct api_key and secret key?", - } - - url = BASE_URL.format( - model_version=version_url or MODEL_VERSIONS[version], - access_token=access_token, - ) - - payload = { - "messages": messages, - "temperature": temperature, - "stream": True - } - headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - } - - with get_httpx_client() as client: - with client.stream("POST", url, headers=headers, json=payload) as response: - for line in response.iter_lines(): - if not line.strip(): - continue - if line.startswith("data: "): - line = line[6:] - resp = json.loads(line) - breakpoint() - yield resp +# @cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次 +# def get_baidu_access_token(api_key: str, secret_key: str) -> str: +# """ +# 使用 AK,SK 生成鉴权签名(Access Token) +# :return: access_token,或是None(如果错误) +# """ +# url = "https://aip.baidubce.com/oauth/2.0/token" +# params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} +# try: +# with get_httpx_client() as client: +# return client.get(url, params=params).json().get("access_token") +# except Exception as e: +# print(f"failed to get token from baidu: {e}") class QianFanWorker(ApiModelWorker): """ 百度千帆 """ + DEFAULT_EMBED_MODEL = "bge-large-zh" def __init__( - self, - *, - version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot", - model_names: List[str] = ["ernie-api"], - controller_addr: str, - worker_addr: str, - **kwargs, + self, + *, + version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot", + model_names: List[str] = ["qianfan-api"], + controller_addr: str = None, + worker_addr: str = None, + **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.setdefault("context_len", 16384) super().__init__(**kwargs) - - config = self.get_config() self.version = version - self.api_key = config.get("api_key") - self.secret_key = config.get("secret_key") - def generate_stream_gate(self, params): - messages = self.prompt_to_messages(params["prompt"]) + def do_chat(self, params: ApiChatParams) -> Dict: + params.load_config(self.model_names[0]) + import qianfan + + comp = qianfan.ChatCompletion(model=params.version, + endpoint=params.version_url, + ak=params.api_key, + sk=params.secret_key,) text = "" - for resp in request_qianfan_api(messages, - temperature=params.get("temperature"), - model_name=self.model_names[0]): - if "result" in resp.keys(): - text += resp["result"] - yield json.dumps({ - "error_code": 0, - "text": text - }, - ensure_ascii=False - ).encode() + b"\0" + for resp in comp.do(messages=params.messages, + temperature=params.temperature, + top_p=params.top_p, + stream=True): + if resp.code == 200: + if chunk := resp.body.get("result"): + text += chunk + yield { + "error_code": 0, + "text": text + } else: - yield json.dumps({ - "error_code": resp["error_code"], - "text": resp["error_msg"] - }, - ensure_ascii=False - ).encode() + b"\0" + yield { + "error_code": resp.code, + "text": str(resp.body), + } + # BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\ + # '/{model_version}?access_token={access_token}' + + # access_token = get_baidu_access_token(params.api_key, params.secret_key) + # if not access_token: + # yield { + # "error_code": 403, + # "text": f"failed to get access token. 
have you set the correct api_key and secret key?", + # } + + # url = BASE_URL.format( + # model_version=params.version_url or MODEL_VERSIONS[params.version], + # access_token=access_token, + # ) + # payload = { + # "messages": params.messages, + # "temperature": params.temperature, + # "stream": True + # } + # headers = { + # 'Content-Type': 'application/json', + # 'Accept': 'application/json', + # } + + # text = "" + # with get_httpx_client() as client: + # with client.stream("POST", url, headers=headers, json=payload) as response: + # for line in response.iter_lines(): + # if not line.strip(): + # continue + # if line.startswith("data: "): + # line = line[6:] + # resp = json.loads(line) + + # if "result" in resp.keys(): + # text += resp["result"] + # yield { + # "error_code": 0, + # "text": text + # } + # else: + # yield { + # "error_code": resp["error_code"], + # "text": resp["error_msg"] + # } + + def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: + import qianfan + params.load_config(self.model_names[0]) + + embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key) + resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL) + if resp.code == 200: + embeddings = [x.embedding for x in resp.body.get("data", [])] + return {"code": 200, "embeddings": embeddings} + else: + return {"code": resp.code, "msg": str(resp.body)} + + + # TODO: qianfan支持续写模型 def get_embeddings(self, params): # TODO: 支持embeddings print("embedding") @@ -159,7 +173,7 @@ class QianFanWorker(ApiModelWorker): # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], - system_message="", + system_message="你是一个聪明的助手,请根据用户的提示来完成任务", messages=[], roles=["user", "assistant"], sep="\n### ", diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py index ab51b8a..2340ec2 100644 --- a/server/model_workers/qwen.py +++ b/server/model_workers/qwen.py @@ -7,94 +7,68 @@ from http import HTTPStatus from typing import List, Literal, Dict from fastchat import conversation as conv - -from server.model_workers.base import ApiModelWorker -from server.utils import get_model_worker_config - - -def request_qwen_api( - messages: List[Dict[str, str]], - api_key: str = None, - version: str = "qwen-turbo", - temperature: float = TEMPERATURE, - model_name: str = "qwen-api", -): - import dashscope - - config = get_model_worker_config(model_name) - api_key = api_key or config.get("api_key") - version = version or config.get("version") - - gen = dashscope.Generation() - responses = gen.call( - model=version, - temperature=temperature, - api_key=api_key, - messages=messages, - result_format='message', # set the result is message format. 
- stream=True, - ) - - text = "" - for resp in responses: - if resp.status_code != HTTPStatus.OK: - yield { - "code": resp.status_code, - "text": "api not response correctly", - } - - if resp["status_code"] == 200: - if choices := resp["output"]["choices"]: - yield { - "code": 200, - "text": choices[0]["message"]["content"], - } - else: - yield { - "code": resp["status_code"], - "text": resp["message"], - } +from server.model_workers.base import * +from server.model_workers.base import ApiEmbeddingsParams class QwenWorker(ApiModelWorker): + DEFAULT_EMBED_MODEL = "text-embedding-v1" + def __init__( - self, - *, - version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo", - model_names: List[str] = ["qwen-api"], - controller_addr: str, - worker_addr: str, - **kwargs, + self, + *, + version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo", + model_names: List[str] = ["qwen-api"], + controller_addr: str = None, + worker_addr: str = None, + **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.setdefault("context_len", 16384) super().__init__(**kwargs) - - config = self.get_config() - self.api_key = config.get("api_key") self.version = version - def generate_stream_gate(self, params): - messages = self.prompt_to_messages(params["prompt"]) + def do_chat(self, params: ApiChatParams) -> Dict: + import dashscope + params.load_config(self.model_names[0]) - for resp in request_qwen_api(messages=messages, - api_key=self.api_key, - version=self.version, - temperature=params.get("temperature")): - if resp["code"] == 200: - yield json.dumps({ - "error_code": 0, - "text": resp["text"] - }, - ensure_ascii=False - ).encode() + b"\0" + gen = dashscope.Generation() + responses = gen.call( + model=params.version, + temperature=params.temperature, + api_key=params.api_key, + messages=params.messages, + result_format='message', # set the result is message format. 
+ stream=True, + ) + + for resp in responses: + if resp["status_code"] == 200: + if choices := resp["output"]["choices"]: + yield { + "error_code": 0, + "text": choices[0]["message"]["content"], + } else: - yield json.dumps({ - "error_code": resp["code"], - "text": resp["text"] - }, - ensure_ascii=False - ).encode() + b"\0" + yield { + "error_code": resp["status_code"], + "text": resp["message"], + } + + def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: + import dashscope + params.load_config(self.model_names[0]) + + resp = dashscope.TextEmbedding.call( + model=params.embed_model or self.DEFAULT_EMBED_MODEL, + input=params.texts, # 最大25行 + api_key=params.api_key, + ) + if resp["status_code"] != 200: + return {"code": resp["status_code"], "msg": resp.message} + else: + embeddings = [x["embedding"] for x in resp["output"]["embeddings"]] + return {"code": 200, "embeddings": embeddings} def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py index 3091c94..0a3a9c8 100644 --- a/server/model_workers/xinghuo.py +++ b/server/model_workers/xinghuo.py @@ -1,12 +1,12 @@ from fastchat.conversation import Conversation -from server.model_workers.base import ApiModelWorker +from server.model_workers.base import * from fastchat import conversation as conv import sys import json from server.model_workers import SparkApi import websockets from server.utils import iter_over_async, asyncio -from typing import List +from typing import List, Dict async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature): @@ -31,46 +31,40 @@ class XingHuoWorker(ApiModelWorker): self, *, model_names: List[str] = ["xinghuo-api"], - controller_addr: str, - worker_addr: str, + controller_addr: str = None, + worker_addr: str = None, + version: str = None, **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.setdefault("context_len", 8192) super().__init__(**kwargs) + self.version = version - def generate_stream_gate(self, params): + def do_chat(self, params: ApiChatParams) -> Dict: # TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接 + params.load_config(self.model_names[0]) - super().generate_stream_gate(params) - config = self.get_config() - appid = config.get("APPID") - api_secret = config.get("APISecret") - api_key = config.get("api_key") - - if config.get("is_v2"): + if params.is_v2: domain = "generalv2" # v2.0版本 Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址 else: domain = "general" # v1.5版本 Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址 - question = self.prompt_to_messages(params["prompt"]) text = "" - try: loop = asyncio.get_event_loop() except: loop = asyncio.new_event_loop() for chunk in iter_over_async( - request(appid, api_key, api_secret, Spark_url, domain, question, params.get("temperature")), + request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages, params.temperature), loop=loop, ): if chunk: - print(chunk) text += chunk - yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0" + yield {"error_code": 0, "text": text} def get_embeddings(self, params): # TODO: 支持embeddings @@ -81,7 +75,7 @@ class XingHuoWorker(ApiModelWorker): # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], - system_message="", + system_message="你是一个聪明的助手,请根据用户的提示来完成任务", messages=[], roles=["user", "assistant"], sep="\n### ", diff --git a/server/model_workers/zhipu.py 
b/server/model_workers/zhipu.py index 34e1396..546dbce 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -1,22 +1,20 @@ from fastchat.conversation import Conversation -from server.model_workers.base import ApiModelWorker +from server.model_workers.base import * from fastchat import conversation as conv import sys -import json -from typing import List, Literal +from typing import List, Dict, Iterator, Literal class ChatGLMWorker(ApiModelWorker): - BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api" - SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"] + DEFAULT_EMBED_MODEL = "text_embedding" def __init__( self, *, model_names: List[str] = ["zhipu-api"], + controller_addr: str = None, + worker_addr: str = None, version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std", - controller_addr: str, - worker_addr: str, **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) @@ -24,27 +22,45 @@ class ChatGLMWorker(ApiModelWorker): super().__init__(**kwargs) self.version = version - def generate_stream_gate(self, params): + def do_chat(self, params: ApiChatParams) -> Iterator[Dict]: # TODO: 维护request_id import zhipuai - super().generate_stream_gate(params) - zhipuai.api_key = self.get_config().get("api_key") + params.load_config(self.model_names[0]) + zhipuai.api_key = params.api_key response = zhipuai.model_api.sse_invoke( - model=self.version, - prompt=[{"role": "user", "content": params["prompt"]}], - temperature=params.get("temperature"), - top_p=params.get("top_p"), + model=params.version, + prompt=params.messages, + temperature=params.temperature, + top_p=params.top_p, incremental=False, ) for e in response.events(): if e.event == "add": - yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0" - # TODO: 更健壮的消息处理 - # elif e.event == "finish": - # ... 
- + yield {"error_code": 0, "text": e.data} + elif e.event in ["error", "interrupted"]: + yield {"error_code": 500, "text": str(e)} + + def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: + import zhipuai + + params.load_config(self.model_names[0]) + zhipuai.api_key = params.api_key + + embeddings = [] + try: + for t in params.texts: + response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t) + if response["code"] == 200: + embeddings.append(response["data"]["embedding"]) + else: + return response # dict with code & msg + except Exception as e: + return {"code": 500, "msg": f"对文本向量化时出错:{e}"} + + return {"code": 200, "embeddings": embeddings} + def get_embeddings(self, params): # TODO: 支持embeddings print("embedding") @@ -56,7 +72,7 @@ class ChatGLMWorker(ApiModelWorker): name=self.model_names[0], system_message="你是一个聪明的助手,请根据用户的提示来完成任务", messages=[], - roles=["Human", "Assistant"], + roles=["Human", "Assistant", "System"], sep="\n###", stop_str="###", ) diff --git a/server/utils.py b/server/utils.py index 90c825d..39c1cb3 100644 --- a/server/utils.py +++ b/server/utils.py @@ -716,3 +716,15 @@ def get_server_configs() -> Dict: } return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom} + + +def list_online_embed_models() -> List[str]: + from server import model_workers + + ret = [] + for k, v in list_config_llm_models()["online"].items(): + provider = v.get("provider") + worker_class = getattr(model_workers, provider, None) + if worker_class is not None and worker_class.can_embedding(): + ret.append(k) + return ret diff --git a/tests/online_api/test_azure.py b/tests/online_api/test_azure.py deleted file mode 100644 index 2432560..0000000 --- a/tests/online_api/test_azure.py +++ /dev/null @@ -1,26 +0,0 @@ -import sys -from pathlib import Path - -root_path = Path(__file__).parent.parent.parent -sys.path.append(str(root_path)) -from server.model_workers.azure import request_azure_api -import pytest -from configs import TEMPERATURE, ONLINE_LLM_MODEL -@pytest.mark.parametrize("version", ["azure-api"]) -def test_azure(version): - messages = [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "介绍一下自己"}] - for resp in request_azure_api( - messages=messages, - api_base_url=ONLINE_LLM_MODEL["azure-api"]["api_base_url"], - api_key=ONLINE_LLM_MODEL["azure-api"]["api_key"], - deployment_name=ONLINE_LLM_MODEL["azure-api"]["deployment_name"], - api_version=ONLINE_LLM_MODEL["azure-api"]["api_version"], - temperature=TEMPERATURE, - max_tokens=1024, - model_name="azure-api" - ): - assert resp["code"] == 200 - - -if __name__ == "__main__": - test_azure("azure-api") \ No newline at end of file diff --git a/tests/online_api/test_baichuan.py b/tests/online_api/test_baichuan.py deleted file mode 100644 index cbe4843..0000000 --- a/tests/online_api/test_baichuan.py +++ /dev/null @@ -1,16 +0,0 @@ -import sys -from pathlib import Path -root_path = Path(__file__).parent.parent.parent -sys.path.append(str(root_path)) - -from server.model_workers.baichuan import request_baichuan_api -from pprint import pprint - - -def test_baichuan(): - messages = [{"role": "user", "content": "hello"}] - - for x in request_baichuan_api(messages): - print(type(x)) - pprint(x) - assert x["code"] == 0 \ No newline at end of file diff --git a/tests/online_api/test_fangzhou.py b/tests/online_api/test_fangzhou.py deleted file mode 100644 index 1157537..0000000 --- a/tests/online_api/test_fangzhou.py +++ /dev/null @@ -1,22 +0,0 @@ -import 
sys -from pathlib import Path -root_path = Path(__file__).parent.parent.parent -sys.path.append(str(root_path)) - -from server.model_workers.fangzhou import request_volc_api -from pprint import pprint -import pytest - - -@pytest.mark.parametrize("version", ["chatglm-6b-model"]) -def test_qianfan(version): - messages = [{"role": "user", "content": "hello"}] - print("\n" + version + "\n") - i = 1 - for x in request_volc_api(messages, version=version): - print(type(x)) - pprint(x) - if chunk := x.choice.message.content: - print(chunk) - assert x.choice.message - i += 1 diff --git a/tests/online_api/test_qianfan.py b/tests/online_api/test_qianfan.py deleted file mode 100644 index 21f6390..0000000 --- a/tests/online_api/test_qianfan.py +++ /dev/null @@ -1,20 +0,0 @@ -import sys -from pathlib import Path -root_path = Path(__file__).parent.parent.parent -sys.path.append(str(root_path)) - -from server.model_workers.qianfan import request_qianfan_api, MODEL_VERSIONS -from pprint import pprint -import pytest - - -@pytest.mark.parametrize("version", list(MODEL_VERSIONS.keys())[:2]) -def test_qianfan(version): - messages = [{"role": "user", "content": "你好"}] - print("\n" + version + "\n") - i = 1 - for x in request_qianfan_api(messages, version=version): - pprint(x) - assert isinstance(x, dict) - assert "error_code" not in x - i += 1 \ No newline at end of file diff --git a/tests/online_api/test_qwen.py b/tests/online_api/test_qwen.py deleted file mode 100644 index 001cf60..0000000 --- a/tests/online_api/test_qwen.py +++ /dev/null @@ -1,19 +0,0 @@ -import sys -from pathlib import Path -root_path = Path(__file__).parent.parent.parent -sys.path.append(str(root_path)) - -from server.model_workers.qwen import request_qwen_api -from pprint import pprint -import pytest - - -@pytest.mark.parametrize("version", ["qwen-turbo"]) -def test_qwen(version): - messages = [{"role": "user", "content": "hello"}] - print("\n" + version + "\n") - - for x in request_qwen_api(messages, version=version): - print(type(x)) - pprint(x) - assert x["code"] == 200 diff --git a/tests/test_online_api.py b/tests/test_online_api.py new file mode 100644 index 0000000..0ddeaee --- /dev/null +++ b/tests/test_online_api.py @@ -0,0 +1,68 @@ +import sys +from pathlib import Path +root_path = Path(__file__).parent.parent +sys.path.append(str(root_path)) + +from configs import ONLINE_LLM_MODEL +from server.model_workers.base import * +from server.utils import get_model_worker_config, list_config_llm_models +from pprint import pprint +import pytest + + +workers = [] +for x in list_config_llm_models()["online"]: + if x in ONLINE_LLM_MODEL and x not in workers: + workers.append(x) +print(f"all workers to test: {workers}") + + +@pytest.mark.parametrize("worker", workers) +def test_chat(worker): + params = ApiChatParams( + messages = [ + {"role": "user", "content": "你是谁"}, + ], + ) + print(f"\nchat with {worker} \n") + + worker_class = get_model_worker_config(worker)["worker_class"] + for x in worker_class().do_chat(params): + pprint(x) + assert isinstance(x, dict) + assert x["error_code"] == 0 + + +@pytest.mark.parametrize("worker", workers) +def test_embeddings(worker): + params = ApiEmbeddingsParams( + texts = [ + "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。", + "一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。", + ] + ) + + worker_class = get_model_worker_config(worker)["worker_class"] + if worker_class.can_embedding(): + print(f"\embeddings with {worker} \n") + resp = 
worker_class().do_embeddings(params) + + pprint(resp, depth=2) + assert resp["code"] == 200 + assert "embeddings" in resp + embeddings = resp["embeddings"] + assert isinstance(embeddings, list) and len(embeddings) > 0 + assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0 + assert isinstance(embeddings[0][0], float) + print("向量长度:", len(embeddings[0])) + + +# @pytest.mark.parametrize("worker", workers) +# def test_completion(worker): +# params = ApiCompletionParams(prompt="五十六个民族") + +# print(f"\completion with {worker} \n") + +# worker_class = get_model_worker_config(worker)["worker_class"] +# resp = worker_class().do_completion(params) +# pprint(resp)
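
Note (not part of the patch above): the refactor replaces each worker's `generate_stream_gate()` plus module-level `request_*()` helper with a uniform `do_chat()` / `do_embeddings()` interface on `ApiModelWorker`, with configuration pulled from `ONLINE_LLM_MODEL` via `ApiChatParams.load_config()`. The sketch below shows, under assumptions, what a new provider worker would look like against this interface. `ExampleWorker`, the `example-api` model name, the `example-embedding` embed model, and the `https://api.example.com/v1/chat` endpoint are hypothetical placeholders for illustration only; the real workers (Azure, Baichuan, FangZhou, MiniMax, Qianfan, Qwen, XingHuo, ZhiPu) are the ones modified in the diff.

```python
# Minimal sketch of a worker written against the refactored base class.
# All provider-specific names/URLs here are hypothetical.
from typing import Dict, Iterator, List

from fastchat import conversation as conv
from fastchat.conversation import Conversation

from server.model_workers.base import (
    ApiChatParams,
    ApiEmbeddingsParams,
    ApiModelWorker,
)
from server.utils import get_httpx_client


class ExampleWorker(ApiModelWorker):
    # Setting DEFAULT_EMBED_MODEL makes can_embedding() return True, so this
    # worker would be picked up by server.utils.list_online_embed_models().
    DEFAULT_EMBED_MODEL = "example-embedding"  # hypothetical model name

    def __init__(
        self,
        *,
        model_names: List[str] = ["example-api"],
        controller_addr: str = None,   # the base class now tolerates None
        worker_addr: str = None,
        version: str = "example-chat",
        **kwargs,
    ):
        kwargs.update(model_names=model_names,
                      controller_addr=controller_addr,
                      worker_addr=worker_addr)
        kwargs.setdefault("context_len", 8192)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
        # load_config() fills api_key / secret_key / version / ... for this
        # worker from model_config.ONLINE_LLM_MODEL, so callers only pass
        # messages and sampling parameters.
        params.load_config(self.model_names[0])
        data = {
            "model": params.version,
            "messages": params.messages,
            "temperature": params.temperature,
            "max_tokens": params.max_tokens,
            "stream": True,
        }
        text = ""
        with get_httpx_client() as client:
            # Placeholder endpoint; a real worker parses the provider's SSE
            # payload here (compare the Azure / MiniMax workers in the diff).
            with client.stream("POST", "https://api.example.com/v1/chat",
                               headers={"Authorization": f"Bearer {params.api_key}"},
                               json=data) as response:
                for line in response.iter_lines():
                    if not line.strip():
                        continue
                    # Accumulate the running reply and yield it in the
                    # {"error_code": int, "text": str} shape the base expects.
                    text += line
                    yield {"error_code": 0, "text": text}

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        params.load_config(self.model_names[0])
        # Expected shape: {"code": int, "embeddings": List[List[float]], "msg": str}
        return {"code": 500,
                "msg": f"{self.model_names[0]} embeddings not implemented in this sketch"}

    def get_embeddings(self, params):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None,
                           model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="You are a helpful assistant.",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )
```

With this split, `ApiModelWorker.generate_stream_gate()` builds an `ApiChatParams` from the fastchat request, calls `do_chat()`, and wraps every yielded dict with `_jsonify()` into the null-terminated JSON frames fastchat expects, while `tests/test_online_api.py` exercises `do_chat()` and `do_embeddings()` directly with these parameter objects.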