Add model switching and support for the Zhipu AI online model (#1342)

* Add an LLM model-switching feature; the switchable models must be configured in server_config
* Add tests for the api.py /llm_model/* endpoints
* - Support switching the running model
  - Support the Zhipu AI online model
  - Add a `--api-worker` option to startup.py that automatically runs all online API models;
    it is enabled by default with `-a (--all-webui)` and `--all-api`
* Restore the standard output streams overridden by fastchat
* Finer-grained control over fastchat logging: add a -q (--quiet) switch to startup.py to reduce noisy fastchat log output
* Fix the conversation template for the chatglm api


Co-authored-by: liunux4odoo <liunu@qq.com>
liunux4odoo 2023-09-01 23:58:09 +08:00 committed by GitHub
parent ab4c8d2e5d
commit 6cb1bdf623
16 changed files with 703 additions and 94 deletions
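A minimal sketch of how the new model-management endpoints can be exercised from Python once the services are running (the 127.0.0.1:7861 address and the 300-second timeout mirror defaults that appear elsewhere in this commit; the model names are only examples):

import requests

api_base = "http://127.0.0.1:7861"

# List the models currently registered with the fastchat controller.
running = requests.post(api_base + "/llm_model/list_models").json()["data"]
print("running models:", running)

# Ask the controller to replace the first running model with the online chatglm-api model.
# The request blocks until the new model_worker has registered, so allow a generous timeout.
r = requests.post(
    api_base + "/llm_model/change",
    json={"model_name": running[0], "new_model_name": "chatglm-api"},
    timeout=300,
)
print(r.json())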

View File

@ -7,6 +7,19 @@ logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
# 分布式部署时不运行LLM的机器上可以不装torch
def default_device():
try:
import torch
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
except:
pass
return "cpu"
# 在以下字典中修改属性值以指定本地embedding模型存储位置
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
# 此处请写绝对路径
@ -34,7 +47,6 @@ EMBEDDING_MODEL = "m3e-base"
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
EMBEDDING_DEVICE = "auto"
llm_model_dict = {
"chatglm-6b": {
"local_model_path": "THUDM/chatglm-6b",
@ -74,6 +86,15 @@ llm_model_dict = {
"api_key": os.environ.get("OPENAI_API_KEY"),
"openai_proxy": os.environ.get("OPENAI_PROXY")
},
# 线上模型。当前支持智谱AI。
# 如果没有设置有效的local_model_path则认为是在线模型API。
# 请在server_config中为每个在线API设置不同的端口
"chatglm-api": {
"api_base_url": "http://127.0.0.1:8888/v1",
"api_key": os.environ.get("ZHIPUAI_API_KEY"),
"provider": "ChatGLMWorker",
"version": "chatglm_pro",
},
}
# LLM 名称
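For reference, a hypothetical sketch of how a second online API could be registered following the chatglm-api pattern above; the model name, environment variable, and provider class are placeholders, not part of this commit:

import os

# Hypothetical extra entry following the same pattern as "chatglm-api" above;
# "some-online-api", SOME_API_KEY and SomeAPIWorker are placeholders.
llm_model_dict_extra = {
    "some-online-api": {
        "api_base_url": "http://127.0.0.1:8889/v1",  # each online API gets its own port
        "api_key": os.environ.get("SOME_API_KEY"),
        "provider": "SomeAPIWorker",  # resolved via getattr(server.model_workers, provider)
        "version": "v1",
    },
}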

View File

@ -1,4 +1,8 @@
from .model_config import LLM_MODEL, LLM_DEVICE
from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
import httpx
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
HTTPX_DEFAULT_TIMEOUT = 300.0
# API 是否开启跨域默认为False如果需要开启请设置为True
# is open cross domain
@ -29,15 +33,18 @@ FSCHAT_OPENAI_API = {
# 这些模型必须是在model_config.llm_model_dict中正确配置的。
# 在启动startup.py时可用通过`--model-worker --model-name xxxx`指定模型不指定则为LLM_MODEL
FSCHAT_MODEL_WORKERS = {
LLM_MODEL: {
# 所有模型共用的默认配置可在模型专项配置或llm_model_dict中进行覆盖。
"default": {
"host": DEFAULT_BIND_HOST,
"port": 20002,
"device": LLM_DEVICE,
# todo: 多卡加载需要配置的参数
# 多卡加载需要配置的参数
# "gpus": None, # 使用的GPU以str的格式指定如"0,1"
# "num_gpus": 1, # 使用GPU的数量
# 以下为非常用参数,可根据需要配置
# "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
# 以下为非常用参数,可根据需要配置
# "load_8bit": False, # 开启8bit量化
# "cpu_offloading": None,
# "gptq_ckpt": None,
@ -53,11 +60,17 @@ FSCHAT_MODEL_WORKERS = {
# "stream_interval": 2,
# "no_register": False,
},
"baichuan-7b": { # 使用default中的IP和端口
"device": "cpu",
},
"chatglm-api": { # 请为每个在线API设置不同的端口
"port": 20003,
},
}
# fastchat multi model worker server
FSCHAT_MULTI_MODEL_WORKERS = {
# todo
# TODO:
}
# fastchat controller server
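Per-model entries here only need to override what differs from "default"; a matching worker entry for the hypothetical online API sketched in the model_config section above would just pin a unique port (name and port are placeholders):

# Hypothetical FSCHAT_MODEL_WORKERS entry matching the sketch in model_config;
# host and device fall back to the "default" block above.
FSCHAT_MODEL_WORKERS_EXTRA = {
    "some-online-api": {
        "port": 20004,  # every online API worker needs a port of its own
    },
}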

View File

@ -4,11 +4,12 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
from configs import VERSION
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
@ -17,7 +18,8 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
import httpx
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -123,6 +125,75 @@ def create_app():
summary="根据content中文档重建向量库流式输出处理进度。"
)(recreate_vector_store)
# LLM模型相关接口
@app.post("/llm_model/list_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型")
def list_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
从fastchat controller获取已加载模型列表
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
except Exception as e:
return BaseResponse(
code=500,
data=[],
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
@app.post("/llm_model/stop",
tags=["LLM Model Management"],
summary="停止指定的LLM模型Model Worker)",
)
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
向fastchat controller请求停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
return BaseResponse(
code=500,
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
@app.post("/llm_model/change",
tags=["LLM Model Management"],
summary="切换指定的LLM模型Model Worker)",
)
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''
向fastchat controller请求切换LLM模型
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
except Exception as e:
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
return app

View File

@ -20,11 +20,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
):
history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
@ -32,10 +34,10 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
@ -61,5 +63,5 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
await task
return StreamingResponse(chat_iterator(query, history),
return StreamingResponse(chat_iterator(query, history, model_name),
media_type="text/event-stream")

View File

@ -31,6 +31,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None,
):
@ -44,16 +45,17 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
kb: KBService,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
@ -97,5 +99,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
await task
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
media_type="text/event-stream")

View File

@ -69,6 +69,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -82,16 +83,17 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = lookup_search_engine(query, search_engine_name, top_k)
@ -129,5 +131,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
ensure_ascii=False)
await task
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
media_type="text/event-stream")

View File

@ -6,7 +6,8 @@ from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text
from configs.model_config import kbs_config
from configs.model_config import EMBEDDING_DEVICE, kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile

View File

@ -0,0 +1 @@
from .zhipu import ChatGLMWorker

View File

@ -0,0 +1,71 @@
from configs.model_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel
import fastchat
import threading
from typing import Dict, List
# 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
class ApiModelOutMsg(BaseModel):
error_code: int = 0
text: str
class ApiModelWorker(BaseModelWorker):
BASE_URL: str
SUPPORT_MODELS: List
def __init__(
self,
model_names: List[str],
controller_addr: str,
worker_addr: str,
context_len: int = 2048,
**kwargs,
):
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
kwargs.setdefault("model_path", "")
kwargs.setdefault("limit_worker_concurrency", 5)
super().__init__(model_names=model_names,
controller_addr=controller_addr,
worker_addr=worker_addr,
**kwargs)
self.context_len = context_len
self.init_heart_beat()
def count_token(self, params):
# TODO需要完善
print("count token")
print(params)
prompt = params["prompt"]
return {"count": len(str(prompt)), "error_code": 0}
def generate_stream_gate(self, params):
self.call_ct += 1
def generate_gate(self, params):
for x in self.generate_stream_gate(params):
pass
return json.loads(x[:-1].decode())
def get_embeddings(self, params):
print("embedding")
print(params)
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()

View File

@ -0,0 +1,75 @@
import zhipuai
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal
class ChatGLMWorker(ApiModelWorker):
BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
def __init__(
self,
*,
model_names: List[str] = ["chatglm-api"],
version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
self.version = version
# 这里的是chatglm api的模板其它API的conv_template需要定制
self.conv = conv.Conversation(
name="chatglm-api",
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["Human", "Assistant"],
sep="\n### ",
stop_str="###",
)
def generate_stream_gate(self, params):
# TODO: 支持stream参数维护request_id传过来的prompt也有问题
from server.utils import get_model_worker_config
super().generate_stream_gate(params)
zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
response = zhipuai.model_api.sse_invoke(
model=self.version,
prompt=[{"role": "user", "content": params["prompt"]}],
temperature=params.get("temperature"),
top_p=params.get("top_p"),
incremental=False,
)
for e in response.events():
if e.event == "add":
yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
# TODO: 更健壮的消息处理
# elif e.event == "finish":
# ...
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = ChatGLMWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20003",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20003)
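The comment in ChatGLMWorker notes that each API needs its own conv_template. A hypothetical skeleton for wiring up another provider in the same way (class name, URL, and model names are placeholders; the streamed chunks must be JSON terminated by b"\0", as in generate_stream_gate above):

import json
from typing import List

from fastchat import conversation as conv
from server.model_workers.base import ApiModelWorker


class SomeAPIWorker(ApiModelWorker):
    # Placeholder values; a real worker would point at the provider's endpoint.
    BASE_URL = "https://api.example.com/v1"
    SUPPORT_MODELS = ["some-model"]

    def __init__(self, *, model_names: List[str] = ["some-online-api"],
                 controller_addr: str, worker_addr: str, **kwargs):
        kwargs.update(model_names=model_names,
                      controller_addr=controller_addr,
                      worker_addr=worker_addr)
        kwargs.setdefault("context_len", 4096)
        super().__init__(**kwargs)
        # Every API needs its own conversation template.
        self.conv = conv.Conversation(
            name="some-online-api",
            system_message="",
            messages=[],
            roles=["Human", "Assistant"],
            sep="\n### ",
            stop_str="###",
        )

    def generate_stream_gate(self, params):
        super().generate_stream_gate(params)
        # A real implementation would call the remote API here and stream its chunks back.
        text = f"echo: {params['prompt']}"  # placeholder response
        yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"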

View File

@ -5,13 +5,17 @@ import torch
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
from typing import Literal, Optional
from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
from configs.server_config import FSCHAT_MODEL_WORKERS
import os
from server import model_workers
from typing import Literal, Optional, Any
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
data: Any = pydantic.Field(None, description="API data")
class Config:
schema_extra = {
@ -201,10 +205,28 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(llm_model_dict.get(model_name, {}))
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
config["device"] = llm_device(config.get("device"))
# 如果没有设置有效的local_model_path则认为是在线模型API
if not os.path.isdir(config.get("local_model_path", "")):
config["online_api"] = True
if provider := config.get("provider"):
try:
config["worker_class"] = getattr(model_workers, provider)
except Exception as e:
print(f"在线模型 {model_name} 的provider没有正确配置")
return config
def get_all_model_worker_configs() -> dict:
result = {}
model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
for name in model_names:
if name != "default":
result[name] = get_model_worker_config(name)
return result
def fschat_controller_address() -> str:
from configs.server_config import FSCHAT_CONTROLLER
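To illustrate the lookup order implemented in get_model_worker_config above (FSCHAT_MODEL_WORKERS["default"] first, overridden by llm_model_dict[model_name], then by FSCHAT_MODEL_WORKERS[model_name]), a sketch of what the merged config for the online model should contain with the configs shown in this commit:

from server.utils import get_model_worker_config

# Sketch of the expected merge result for the online API model configured above.
cfg = get_model_worker_config("chatglm-api")
print(cfg["host"])            # DEFAULT_BIND_HOST, from FSCHAT_MODEL_WORKERS["default"]
print(cfg["port"])            # 20003, from FSCHAT_MODEL_WORKERS["chatglm-api"]
print(cfg["provider"])        # "ChatGLMWorker", from llm_model_dict["chatglm-api"]
print(cfg.get("online_api"))  # True, because no valid local_model_path is set
print(cfg["worker_class"])    # the ChatGLMWorker class resolved from server.model_workers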

View File

@ -16,14 +16,14 @@ except:
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
logger
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API, )
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout,
llm_device, embedding_device, get_model_worker_config)
from server.utils import MakeFastAPIOffline, FastAPI
get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse
from typing import Tuple, List
from typing import Tuple, List, Dict
from configs import VERSION
@ -41,6 +41,7 @@ def create_controller_app(
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
app._controller = controller
return app
@ -97,43 +98,62 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
# 在线模型API
if worker_class := kwargs.get("worker_class"):
worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address,
worker_addr=args.worker_address)
# 本地模型
else:
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
)
ModelWorker.init_heart_beat = _new_init_heart_beat
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({LLM_MODEL})"
app.title = f"FastChat LLM Server ({args.model_names[0]})"
app._worker = worker
return app
@ -190,6 +210,9 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
import uvicorn
import httpx
from fastapi import Body
import time
import sys
app = create_controller_app(
@ -198,11 +221,71 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
)
_set_app_seq(app, q, run_seq)
# add interface to release and load model worker
@app.post("/release_worker")
def release_worker(
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
available_models = app._controller.list_models()
if new_model_name in available_models:
msg = f"要切换的LLM模型 {new_model_name} 已经存在"
logger.info(msg)
return {"code": 500, "msg": msg}
if new_model_name:
logger.info(f"开始切换LLM模型{model_name}{new_model_name}")
else:
logger.info(f"即将停止LLM模型 {model_name}")
if model_name not in available_models:
msg = f"the model {model_name} is not available"
logger.error(msg)
return {"code": 500, "msg": msg}
worker_address = app._controller.get_worker_address(model_name)
if not worker_address:
msg = f"can not find model_worker address for {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
r = httpx.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200:
msg = f"failed to release model: {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
if new_model_name:
timer = 300 # wait 5 minutes for new model_worker register
while timer > 0:
models = app._controller.list_models()
if new_model_name in models:
break
time.sleep(1)
timer -= 1
            if timer > 0:
                msg = f"successfully changed model from {model_name} to {new_model_name}"
                logger.info(msg)
                return {"code": 200, "msg": msg}
            else:
                msg = f"failed to change model from {model_name} to {new_model_name}"
                logger.error(msg)
                return {"code": 500, "msg": msg}
        else:
            msg = f"successfully released model: {model_name}"
            logger.info(msg)
            return {"code": 200, "msg": msg}
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
@ -214,16 +297,17 @@ def run_model_worker(
log_level: str ="INFO",
):
import uvicorn
from fastapi import Body
import sys
kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host")
port = kwargs.pop("port")
model_path = llm_model_dict[model_name].get("local_model_path", "")
kwargs["model_path"] = model_path
kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address()
kwargs["worker_address"] = fschat_model_worker_address()
kwargs["worker_address"] = fschat_model_worker_address(model_name)
model_path = kwargs.get("local_model_path", "")
kwargs["model_path"] = model_path
app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_seq(app, q, run_seq)
@ -231,6 +315,22 @@ def run_model_worker(
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
# add interface to release and load model
@app.post("/release")
def release_model(
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
if keep_origin:
if new_model_name:
q.put(["start", new_model_name])
else:
if new_model_name:
q.put(["replace", new_model_name])
else:
q.put(["stop"])
return {"code": 200, "msg": "done"}
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
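The queue consumer that reacts to these messages is not part of this excerpt; a hypothetical sketch of how the ["start"/"replace"/"stop"] commands put on the queue above could be handled by the parent process:

# Hypothetical consumer for the /release queue messages; the project's actual handling
# loop in startup.py is not shown in this diff.
def handle_release_message(msg, start_worker, stop_worker):
    cmd = msg[0]
    if cmd == "start":        # keep_origin=True: start an extra worker for the new model
        start_worker(msg[1])
    elif cmd == "replace":    # stop the current worker, then start the new model
        stop_worker()
        start_worker(msg[1])
    elif cmd == "stop":       # plain release without a replacement
        stop_worker()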
@ -337,6 +437,13 @@ def parse_args() -> argparse.ArgumentParser:
help="run api.py server",
dest="api",
)
parser.add_argument(
"-p",
"--api-worker",
action="store_true",
help="run online model api such as zhipuai",
dest="api_worker",
)
parser.add_argument(
"-w",
"--webui",
@ -368,9 +475,14 @@ def dump_server_info(after_start=False, args=None):
print(f"项目版本:{VERSION}")
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
print(f"当前LLM模型{LLM_MODEL} @ {llm_device()}")
pprint(llm_model_dict[LLM_MODEL])
model = LLM_MODEL
if args and args.model_name:
model = args.model_name
print(f"当前LLM模型{model} @ {llm_device()}")
pprint(llm_model_dict[model])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")
if after_start:
print("\n")
print(f"服务端运行信息:")
@ -396,17 +508,20 @@ if __name__ == "__main__":
args.openai_api = True
args.model_worker = True
args.api = True
args.api_worker = True
args.webui = True
elif args.all_api:
args.openai_api = True
args.model_worker = True
args.api = True
args.api_worker = True
args.webui = False
elif args.llm_api:
args.openai_api = True
args.model_worker = True
args.api_worker = True
args.api = False
args.webui = False
@ -416,7 +531,11 @@ if __name__ == "__main__":
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
processes = {}
processes = {"online-api": []}
def process_count():
return len(processes) + len(processes["online-api"]) -1
if args.quiet:
log_level = "ERROR"
else:
@ -426,7 +545,7 @@ if __name__ == "__main__":
process = Process(
target=run_controller,
name=f"controller({os.getpid()})",
args=(queue, len(processes) + 1, log_level),
args=(queue, process_count() + 1, log_level),
daemon=True,
)
process.start()
@ -435,29 +554,42 @@ if __name__ == "__main__":
process = Process(
target=run_openai_api,
name=f"openai_api({os.getpid()})",
args=(queue, len(processes) + 1),
args=(queue, process_count() + 1),
daemon=True,
)
process.start()
processes["openai_api"] = process
if args.model_worker:
model_path = llm_model_dict[args.model_name].get("local_model_path", "")
if os.path.isdir(model_path):
config = get_model_worker_config(args.model_name)
if not config.get("online_api"):
process = Process(
target=run_model_worker,
name=f"model_worker({os.getpid()})",
args=(args.model_name, args.controller_address, queue, len(processes) + 1, log_level),
name=f"model_worker - {args.model_name} ({os.getpid()})",
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
process.start()
processes["model_worker"] = process
if args.api_worker:
configs = get_all_model_worker_configs()
for model_name, config in configs.items():
if config.get("online_api") and config.get("worker_class"):
process = Process(
target=run_model_worker,
name=f"model_worker - {model_name} ({os.getpid()})",
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
process.start()
processes["online-api"].append(process)
if args.api:
process = Process(
target=run_api_server,
name=f"API Server{os.getpid()})",
args=(queue, len(processes) + 1),
args=(queue, process_count() + 1),
daemon=True,
)
process.start()
@ -467,37 +599,38 @@ if __name__ == "__main__":
process = Process(
target=run_webui,
name=f"WEBUI Server{os.getpid()})",
args=(queue, len(processes) + 1),
args=(queue, process_count() + 1),
daemon=True,
)
process.start()
processes["webui"] = process
if len(processes) == 0:
if process_count() == 0:
parser.print_help()
else:
try:
            # log startup info
while True:
no = queue.get()
if no == len(processes):
if no == process_count():
time.sleep(0.5)
dump_server_info(after_start=True, args=args)
break
else:
queue.put(no)
if model_worker_process := processes.get("model_worker"):
if model_worker_process := processes.pop("model_worker", None):
model_worker_process.join()
for process in processes.pop("online-api", []):
process.join()
for name, process in processes.items():
if name != "model_worker":
process.join()
process.join()
except:
if model_worker_process := processes.get("model_worker"):
if model_worker_process := processes.pop("model_worker", None):
model_worker_process.terminate()
for process in processes.pop("online-api", []):
process.terminate()
for name, process in processes.items():
if name != "model_worker":
process.terminate()
process.terminate()
# 服务启动后接口调用示例:

View File

@ -1,4 +1,3 @@
from doctest import testfile
import requests
import json
import sys

tests/api/test_llm_api.py (new file, 74 lines)
View File

@ -0,0 +1,74 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL, llm_model_dict
from pprint import pprint
import random
def get_configured_models():
model_workers = list(FSCHAT_MODEL_WORKERS)
if "default" in model_workers:
model_workers.remove("default")
llm_dict = list(llm_model_dict)
return model_workers, llm_dict
api_base_url = api_address()
def get_running_models(api="/llm_model/list_models"):
url = api_base_url + api
r = requests.post(url)
if r.status_code == 200:
return r.json()["data"]
return []
def test_running_models(api="/llm_model/list_models"):
url = api_base_url + api
r = requests.post(url)
assert r.status_code == 200
print("\n获取当前正在运行的模型列表:")
pprint(r.json())
assert isinstance(r.json()["data"], list)
assert len(r.json()["data"]) > 0
# 不建议使用stop_model功能。按现在的实现停止了就只能手动再启动
# def test_stop_model(api="/llm_model/stop"):
# url = api_base_url + api
# r = requests.post(url, json={""})
def test_change_model(api="/llm_model/change"):
url = api_base_url + api
running_models = get_running_models()
assert len(running_models) > 0
model_workers, llm_dict = get_configured_models()
    available_new_models = set(model_workers) - set(running_models)
    if len(available_new_models) == 0:
        available_new_models = set(llm_dict) - set(running_models)
    available_new_models = list(available_new_models)
    assert len(available_new_models) > 0
    print(available_new_models)
    model_name = random.choice(running_models)
    new_model_name = random.choice(available_new_models)
print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
assert r.status_code == 200
running_models = get_running_models()
assert new_model_name in running_models

View File

@ -1,10 +1,14 @@
import streamlit as st
from configs.server_config import FSCHAT_MODEL_WORKERS
from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
from typing import List, Dict
import os
from configs.model_config import llm_model_dict, LLM_MODEL
from server.utils import get_model_worker_config
from typing import List, Dict
chat_box = ChatBox(
assistant_avatar=os.path.join(
@ -59,6 +63,38 @@ def dialogue_page(api: ApiRequest):
on_change=on_mode_change,
key="dialogue_mode",
)
def on_llm_change():
st.session_state["prev_llm_model"] = llm_model
def llm_model_format_func(x):
if x in running_models:
return f"{x} (Running)"
return x
running_models = api.list_running_models()
config_models = api.list_config_models()
for x in running_models:
if x in config_models:
config_models.remove(x)
llm_models = running_models + config_models
if "prev_llm_model" not in st.session_state:
index = llm_models.index(LLM_MODEL)
else:
index = 0
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
# key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not get_model_worker_config(llm_model).get("online_api")):
with st.spinner(f"正在加载模型: {llm_model}"):
r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
st.session_state["prev_llm_model"] = llm_model
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
def on_kb_change():
@ -99,7 +135,7 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
r = api.chat_chat(prompt, history)
r = api.chat_chat(prompt, history=history, model=llm_model)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
@ -114,7 +150,7 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"),
])
text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
text += d["answer"]
@ -127,8 +163,8 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="网络搜索结果"),
])
text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
if error_msg := check_error_msg(d): # check whether error occured
for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
else:
text += d["answer"]

View File

@ -6,12 +6,14 @@ from configs.model_config import (
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
llm_model_dict,
HISTORY_LEN,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
logger,
)
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
@ -42,7 +44,7 @@ class ApiRequest:
def __init__(
self,
base_url: str = "http://127.0.0.1:7861",
timeout: float = 60.0,
timeout: float = HTTPX_DEFAULT_TIMEOUT,
no_remote_api: bool = False, # call api view function directly
):
self.base_url = base_url
@ -289,6 +291,7 @@ class ApiRequest:
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@ -301,6 +304,7 @@ class ApiRequest:
"query": query,
"history": history,
"stream": stream,
"model_name": model,
}
print(f"received input message:")
@ -322,6 +326,7 @@ class ApiRequest:
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@ -337,6 +342,7 @@ class ApiRequest:
"score_threshold": score_threshold,
"history": history,
"stream": stream,
"model_name": model,
"local_doc_url": no_remote_api,
}
@ -361,6 +367,7 @@ class ApiRequest:
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@ -374,6 +381,7 @@ class ApiRequest:
"search_engine_name": search_engine_name,
"top_k": top_k,
"stream": stream,
"model_name": model,
}
print(f"received input message:")
@ -645,6 +653,84 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
def list_running_models(self, controller_address: str = None):
'''
获取Fastchat中正运行的模型列表
'''
r = self.post(
"/llm_model/list_models",
)
return r.json().get("data", [])
def list_config_models(self):
'''
获取configs中配置的模型列表
'''
return list(llm_model_dict.keys())
def stop_llm_model(
self,
model_name: str,
controller_address: str = None,
):
'''
停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
data = {
"model_name": model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/stop",
json=data,
)
return r.json()
def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
):
'''
向fastchat controller请求切换LLM模型
'''
if not model_name or not new_model_name:
return
if new_model_name == model_name:
return {
"code": 200,
"msg": "什么都不用做"
}
running_models = self.list_running_models()
if model_name not in running_models:
return {
"code": 500,
"msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
}
config_models = self.list_config_models()
if new_model_name not in config_models:
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
}
data = {
"model_name": model_name,
"new_model_name": new_model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/change",
json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
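A short usage sketch of the new client helpers (the address and model name are examples; change_llm_model blocks until the replacement worker has registered):

from webui_pages.utils import ApiRequest

# Sketch: manage the running LLM model through the new client helpers.
api = ApiRequest(base_url="http://127.0.0.1:7861")
running = api.list_running_models()    # models registered with the fastchat controller
configured = api.list_config_models()  # every model declared in llm_model_dict
print(running, configured)

if running and "chatglm-api" in configured:
    print(api.change_llm_model(running[0], "chatglm-api"))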
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''