Improve how the LLM model list is fetched and switched (#1497)

1. Fetch available (not-yet-running) models more accurately
2. Improve the WebUI's control logic for displaying and switching the model list
liunux4odoo 2023-09-16 07:15:08 +08:00 committed by GitHub
parent 955b0bc211
commit 3dde02be28
7 changed files with 80 additions and 33 deletions


@@ -40,6 +40,7 @@ MODEL_PATH = {
"chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
},
}
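
The added line registers baichuan-7b in what appears to be the model config's MODEL_PATH table, mapping the model name to its Hugging Face repo id. A minimal sketch of how such a mapping can be resolved to a loadable path (illustrative assumptions only: resolve_model_path and the MODEL_ROOT_PATH fallback below are not the project's actual get_model_path, which is touched further down):

from pathlib import Path

# Mirrors the MODEL_PATH hunk above; MODEL_ROOT_PATH is an assumed example value.
MODEL_PATH = {"llm_model": {"baichuan-7b": "baichuan-inc/Baichuan-7B"}}
MODEL_ROOT_PATH = "/data/models"

def resolve_model_path(name: str) -> str:
    # Prefer a local checkout under MODEL_ROOT_PATH; otherwise return the
    # repo id so downstream loaders can pull it from the Hugging Face hub.
    repo_id = MODEL_PATH["llm_model"].get(name, name)
    local = Path(MODEL_ROOT_PATH) / repo_id.split("/")[-1]
    return str(local) if local.is_dir() else repo_id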


@@ -17,7 +17,7 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from typing import List
@@ -125,10 +125,15 @@ def create_app():
)(recreate_vector_store)
# LLM模型相关接口
app.post("/llm_model/list_models",
app.post("/llm_model/list_running_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型",
)(list_llm_models)
)(list_running_models)
app.post("/llm_model/list_config_models",
tags=["LLM Model Management"],
summary="列出configs已配置的模型",
)(list_config_models)
app.post("/llm_model/stop",
tags=["LLM Model Management"],


@@ -1,10 +1,10 @@
from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from server.utils import BaseResponse, fschat_controller_address
from server.utils import BaseResponse, fschat_controller_address, list_llm_models
import httpx
def list_llm_models(
def list_running_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
placeholder: str = Body(None, description="该参数未使用,占位用"),
) -> BaseResponse:
@@ -24,6 +24,13 @@ def list_llm_models(
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
def list_config_models() -> BaseResponse:
'''
从本地获取configs中配置的模型列表
'''
return BaseResponse(data=list_llm_models())
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
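
After the rename, list_running_models still asks the FastChat controller which workers are alive, while the new list_config_models answers purely from local configs. A simplified sketch of the controller round-trip (POST /refresh_all_workers followed by POST /list_models are FastChat's controller routes; error handling is elided here):

import httpx

def running_models_sketch(controller_address: str) -> list:
    # Ask the controller to re-poll its workers, then list the live ones.
    with httpx.Client(base_url=controller_address, timeout=30) as client:
        client.post("/refresh_all_workers")
        return client.post("/list_models").json().get("models", [])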


@@ -72,7 +72,10 @@ def request_qianfan_api(
version_url = config.get("version_url")
access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
if not access_token:
raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
yield {
"error_code": 403,
"error_msg": f"failed to get access token. have you set the correct api_key and secret key?",
}
url = BASE_URL.format(
model_version=version_url or MODEL_VERSIONS[version],
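
Yielding an error dict instead of raising keeps request_qianfan_api a well-behaved generator: a streaming caller receives a structured failure rather than an exception escaping mid-stream. A consumption sketch under that assumption (the "result" key follows Qianfan's usual reply shape and is not shown in this hunk):

def consume_stream(chunks):
    # Stop on the structured error dict yielded above; otherwise print text.
    for chunk in chunks:
        if chunk.get("error_code"):
            print(f"qianfan error {chunk['error_code']}: {chunk.get('error_msg')}")
            return
        print(chunk.get("result", ""), end="")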


@@ -5,12 +5,12 @@ from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
MODEL_PATH, MODEL_ROOT_PATH,
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
logger, log_verbose,
FSCHAT_MODEL_WORKERS)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Literal, Optional, Callable, Generator, Dict, Any
from typing import Literal, Optional, Callable, Generator, Dict, Any, Tuple
thread_pool = ThreadPoolExecutor(os.cpu_count())
@@ -201,10 +201,25 @@ def MakeFastAPIOffline(
# 从model_config中获取模型信息
def list_embed_models() -> List[str]:
'''
get names of configured embedding models
'''
return list(MODEL_PATH["embed_model"])
def list_llm_models() -> List[str]:
return list(MODEL_PATH["llm_model"])
def list_llm_models() -> Dict[str, List[str]]:
'''
get names of configured llm models with different types.
return {config_type: [model_name, ...], ...}
'''
workers = list(FSCHAT_MODEL_WORKERS)
if "default" in workers:
workers.remove("default")
return {
"local": list(MODEL_PATH["llm_model"]),
"online": list(ONLINE_LLM_MODEL),
"worker": workers,
}
def get_model_path(model_name: str, type: str = None) -> Optional[str]:
if type in MODEL_PATH:
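
The reworked list_llm_models groups configured model names by source, which is what lets callers tell "configured but not running" apart from running workers. Given the hunk above, its return value looks roughly like this (model names are illustrative):

config_models = {
    "local": ["chatglm2-6b", "chatglm2-6b-int4", "chatglm2-6b-32k", "baichuan-7b"],
    "online": ["zhipu-api", "minimax-api"],  # keys of ONLINE_LLM_MODEL; names assumed
    "worker": [],  # FSCHAT_MODEL_WORKERS entries minus the "default" template
}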


@@ -68,6 +68,7 @@ def dialogue_page(api: ApiRequest):
config = get_model_worker_config(llm_model)
if not config.get("online_api"): # 只有本地model_worker可以切换模型
st.session_state["prev_llm_model"] = llm_model
st.session_state["cur_llm_model"] = st.session_state.llm_model
def llm_model_format_func(x):
if x in running_models:
@@ -75,25 +76,32 @@ def dialogue_page(api: ApiRequest):
return x
running_models = api.list_running_models()
available_models = []
config_models = api.list_config_models()
for x in running_models:
if x in config_models:
config_models.remove(x)
llm_models = running_models + config_models
cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
index = llm_models.index(cur_model)
for models in config_models.values():
for m in models:
if m not in running_models:
available_models.append(m)
llm_models = running_models + available_models
index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL))
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
# key="llm_model",
key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not get_model_worker_config(llm_model).get("online_api")):
and not get_model_worker_config(llm_model).get("online_api")
and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
st.session_state["cur_llm_model"] = llm_model
prev_model = st.session_state.get("prev_llm_model")
r = api.change_llm_model(prev_model, llm_model)
if msg := check_error_msg(r):
st.error(msg)
elif msg := check_success_msg(r):
st.success(msg)
st.session_state["prev_llm_model"] = llm_model
temperature = st.slider("Temperature", 0.0, 1.0, TEMPERATURE, 0.05)
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
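
The selectbox logic above boils down to: running models first, then every configured model that is not yet running, and a worker switch is requested only for local models that are not already loaded. The merge step can be isolated as a pure function (a sketch, not code from the commit):

from typing import Dict, List

def merge_model_lists(running: List[str], configured: Dict[str, List[str]]) -> List[str]:
    # Running models keep their place at the head of the list; every
    # configured-but-idle model stays selectable behind them.
    available = [m for models in configured.values()
                 for m in models if m not in running]
    return running + available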


@@ -766,23 +766,31 @@ class ApiRequest:
"controller_address": controller_address,
}
if no_remote_api:
from server.llm_api import list_llm_models
return list_llm_models(**data).data
from server.llm_api import list_running_models
return list_running_models(**data).data
else:
r = self.post(
"/llm_model/list_models",
"/llm_model/list_running_models",
json=data,
)
return r.json().get("data", [])
def list_config_models(self):
def list_config_models(self, no_remote_api: bool = None) -> Dict[str, List[str]]:
'''
获取configs中配置的模型列表
获取configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}
如果no_remote_api=True, 从运行ApiRequest的机器上获取;否则从运行api.py的机器上获取
'''
models = list(FSCHAT_MODEL_WORKERS.keys())
if "default" in models:
models.remove("default")
return models
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
from server.llm_api import list_config_models
return list_config_models().data
else:
r = self.post(
"/llm_model/list_config_models",
)
return r.json().get("data", {})
def stop_llm_model(
self,
@@ -828,13 +836,13 @@ class ApiRequest:
if not model_name or not new_model_name:
return
if new_model_name == model_name:
running_models = self.list_running_models()
if new_model_name == model_name or new_model_name in running_models:
return {
"code": 200,
"msg": "什么都不用做"
"msg": "无需切换"
}
running_models = self.list_running_models()
if model_name not in running_models:
return {
"code": 500,
@@ -842,7 +850,7 @@ class ApiRequest:
}
config_models = self.list_config_models()
if new_model_name not in config_models:
if new_model_name not in config_models.get("local", []):
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"