Optimize the logic for fetching and switching the LLM model list: (#1497)

1. Fetch available (not-yet-running) models more accurately
2. Improve the WebUI's model-list display and switching control logic
liunux4odoo 2023-09-16 07:15:08 +08:00 committed by GitHub
parent 955b0bc211
commit 3dde02be28
7 changed files with 80 additions and 33 deletions

View File

@@ -40,6 +40,7 @@ MODEL_PATH = {
         "chatglm2-6b": "THUDM/chatglm2-6b",
         "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
         "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
+        "baichuan-7b": "baichuan-inc/Baichuan-7B",
     },
 }
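For reference, entries under MODEL_PATH["llm_model"] map a short model name to a Hugging Face repo id or a local path. Below is a minimal sketch of how such an entry might be resolved; the MODEL_ROOT_PATH fallback shown is an assumption for illustration, not the repo's exact lookup order (see get_model_path in server/utils.py further down).

import os

# Hypothetical, trimmed-down resolver in the spirit of get_model_path():
MODEL_ROOT_PATH = ""  # assumed empty unless the user sets a local model root
MODEL_PATH = {"llm_model": {"baichuan-7b": "baichuan-inc/Baichuan-7B"}}

def resolve_model_path(model_name: str) -> str:
    path_str = MODEL_PATH["llm_model"].get(model_name, model_name)
    # Prefer a local copy under MODEL_ROOT_PATH if one exists; otherwise
    # return the configured value (typically a Hugging Face repo id).
    if MODEL_ROOT_PATH:
        local = os.path.join(MODEL_ROOT_PATH, os.path.basename(path_str))
        if os.path.isdir(local):
            return local
    return path_str

print(resolve_model_path("baichuan-7b"))  # -> "baichuan-inc/Baichuan-7B"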

View File

@@ -17,7 +17,7 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                               update_docs, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
-from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
+from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
 from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
 from typing import List
@@ -125,10 +125,15 @@ def create_app():
              )(recreate_vector_store)

     # LLM模型相关接口
-    app.post("/llm_model/list_models",
+    app.post("/llm_model/list_running_models",
              tags=["LLM Model Management"],
              summary="列出当前已加载的模型",
-             )(list_llm_models)
+             )(list_running_models)
+
+    app.post("/llm_model/list_config_models",
+             tags=["LLM Model Management"],
+             summary="列出configs已配置的模型",
+             )(list_config_models)

     app.post("/llm_model/stop",
              tags=["LLM Model Management"],
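Both new routes are POST endpoints: list_running_models forwards to the fschat controller, while list_config_models reads the local configs. A hedged usage sketch follows; the base URL is an assumption, so point it at wherever api.py is actually serving.

import httpx

base_url = "http://127.0.0.1:7861"  # assumed default; adjust to your deployment

running = httpx.post(f"{base_url}/llm_model/list_running_models", json={}).json()
configured = httpx.post(f"{base_url}/llm_model/list_config_models").json()

print(running.get("data", []))     # e.g. ["chatglm2-6b"]
print(configured.get("data", {}))  # e.g. {"local": [...], "online": [...], "worker": [...]}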

View File

@@ -1,10 +1,10 @@
 from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
-from server.utils import BaseResponse, fschat_controller_address
+from server.utils import BaseResponse, fschat_controller_address, list_llm_models
 import httpx


-def list_llm_models(
+def list_running_models(
     controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
     placeholder: str = Body(None, description="该参数未使用,占位用"),
 ) -> BaseResponse:
@@ -24,6 +24,13 @@ def list_llm_models(
             msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")


+def list_config_models() -> BaseResponse:
+    '''
+    从本地获取configs中配置的模型列表
+    '''
+    return BaseResponse(data=list_llm_models())
+
+
 def stop_llm_model(
     model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
     controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
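Since list_config_models wraps its result in the repo's BaseResponse, the endpoint's JSON body should look roughly like the sketch below (model names are placeholders and the msg default is an assumption):

# Illustrative response body for /llm_model/list_config_models:
resp = {
    "code": 200,
    "msg": "success",
    "data": {
        "local": ["chatglm2-6b", "baichuan-7b"],
        "online": ["zhipu-api"],
        "worker": [],
    },
}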

View File

@@ -72,7 +72,10 @@ def request_qianfan_api(
     version_url = config.get("version_url")
     access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
     if not access_token:
-        raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
+        yield {
+            "error_code": 403,
+            "error_msg": f"failed to get access token. have you set the correct api_key and secret key?",
+        }

     url = BASE_URL.format(
         model_version=version_url or MODEL_VERSIONS[version],
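Replacing the raise with a yield keeps request_qianfan_api's streaming contract intact: the caller keeps iterating a generator and can branch on error_code instead of catching an exception. A minimal consumer sketch; the "text" key for normal chunks is assumed for illustration.

def consume(chunks):
    # `chunks` is whatever request_qianfan_api(...) yields.
    for chunk in chunks:
        if chunk.get("error_code"):  # e.g. 403 when no access token was obtained
            print("API error:", chunk.get("error_msg"))
            break
        print(chunk.get("text", ""), end="")

# Stubbed stream demonstrating the error path:
consume(iter([{"error_code": 403, "error_msg": "failed to get access token"}]))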

View File

@@ -5,12 +5,12 @@ from fastapi import FastAPI
 from pathlib import Path
 import asyncio
 from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
-                     MODEL_PATH, MODEL_ROOT_PATH,
+                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
                      logger, log_verbose,
                      FSCHAT_MODEL_WORKERS)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Literal, Optional, Callable, Generator, Dict, Any
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Tuple


 thread_pool = ThreadPoolExecutor(os.cpu_count())
@@ -201,10 +201,25 @@ def MakeFastAPIOffline(
 # 从model_config中获取模型信息
 def list_embed_models() -> List[str]:
+    '''
+    get names of configured embedding models
+    '''
     return list(MODEL_PATH["embed_model"])


-def list_llm_models() -> List[str]:
-    return list(MODEL_PATH["llm_model"])
+def list_llm_models() -> Dict[str, List[str]]:
+    '''
+    get names of configured llm models with different types.
+    return {config_type: [model_name, ...], ...}
+    '''
+    workers = list(FSCHAT_MODEL_WORKERS)
+    if "default" in workers:
+        workers.remove("default")
+    return {
+        "local": list(MODEL_PATH["llm_model"]),
+        "online": list(ONLINE_LLM_MODEL),
+        "worker": workers,
+    }


 def get_model_path(model_name: str, type: str = None) -> Optional[str]:
     if type in MODEL_PATH:
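list_llm_models now groups configured models by type instead of returning a flat list. Callers that need a flat view can merge the groups, which is exactly the shape the WebUI consumes below. A self-contained sketch with placeholder model names:

config_models = {
    "local": ["chatglm2-6b", "baichuan-7b"],
    "online": ["zhipu-api"],
    "worker": [],
}
running_models = ["chatglm2-6b"]

# Models that are configured but not yet running:
available_models = [
    m
    for models in config_models.values()
    for m in models
    if m not in running_models
]
print(available_models)  # ['baichuan-7b', 'zhipu-api']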

View File

@@ -68,6 +68,7 @@ def dialogue_page(api: ApiRequest):
             config = get_model_worker_config(llm_model)
             if not config.get("online_api"): # 只有本地model_worker可以切换模型
                 st.session_state["prev_llm_model"] = llm_model
+            st.session_state["cur_llm_model"] = st.session_state.llm_model

         def llm_model_format_func(x):
             if x in running_models:
@@ -75,25 +76,32 @@ def dialogue_page(api: ApiRequest):
             return x

         running_models = api.list_running_models()
+        available_models = []
         config_models = api.list_config_models()
-        for x in running_models:
-            if x in config_models:
-                config_models.remove(x)
-        llm_models = running_models + config_models
-        cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
-        index = llm_models.index(cur_model)
+        for models in config_models.values():
+            for m in models:
+                if m not in running_models:
+                    available_models.append(m)
+        llm_models = running_models + available_models
+        index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL))
         llm_model = st.selectbox("选择LLM模型",
                                  llm_models,
                                  index,
                                  format_func=llm_model_format_func,
                                  on_change=on_llm_change,
-                                 # key="llm_model",
+                                 key="llm_model",
                                  )
         if (st.session_state.get("prev_llm_model") != llm_model
-            and not get_model_worker_config(llm_model).get("online_api")):
+            and not get_model_worker_config(llm_model).get("online_api")
+            and llm_model not in running_models):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
-                r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
-                st.session_state["cur_llm_model"] = llm_model
+                prev_model = st.session_state.get("prev_llm_model")
+                r = api.change_llm_model(prev_model, llm_model)
+                if msg := check_error_msg(r):
+                    st.error(msg)
+                elif msg := check_success_msg(r):
+                    st.success(msg)
+                    st.session_state["prev_llm_model"] = llm_model

         temperature = st.slider("Temperature", 0.0, 1.0, TEMPERATURE, 0.05)
         history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
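Note the re-enabled key="llm_model": the selectbox value now lives in st.session_state, and on_llm_change records the previous and current choices so a rerun can tell a genuine switch from a simple reselection. The guard before the spinner reads as a three-part predicate; a distilled, hypothetical version for readability:

def should_switch(prev_model: str, new_model: str,
                  is_online_api: bool, running_models: list) -> bool:
    return (
        prev_model != new_model                  # user picked a different model
        and not is_online_api                    # online API models need no local worker
        and new_model not in running_models      # already-running models are just selected
    )

print(should_switch("chatglm2-6b", "baichuan-7b", False, ["chatglm2-6b"]))  # True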

View File

@@ -766,23 +766,31 @@ class ApiRequest:
             "controller_address": controller_address,
         }
         if no_remote_api:
-            from server.llm_api import list_llm_models
-            return list_llm_models(**data).data
+            from server.llm_api import list_running_models
+            return list_running_models(**data).data
         else:
             r = self.post(
-                "/llm_model/list_models",
+                "/llm_model/list_running_models",
                 json=data,
             )
             return r.json().get("data", [])

-    def list_config_models(self):
+    def list_config_models(self, no_remote_api: bool = None) -> Dict[str, List[str]]:
         '''
-        获取configs中配置的模型列表
+        获取configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。
+        如果no_remote_api=True, 从运行ApiRequest的机器上获取;否则从运行api.py的机器上获取。
         '''
-        models = list(FSCHAT_MODEL_WORKERS.keys())
-        if "default" in models:
-            models.remove("default")
-        return models
+        if no_remote_api is None:
+            no_remote_api = self.no_remote_api
+
+        if no_remote_api:
+            from server.llm_api import list_config_models
+            return list_config_models().data
+        else:
+            r = self.post(
+                "/llm_model/list_config_models",
+            )
+            return r.json().get("data", {})

     def stop_llm_model(
         self,
@@ -828,13 +836,13 @@ class ApiRequest:
         if not model_name or not new_model_name:
             return

-        if new_model_name == model_name:
+        running_models = self.list_running_models()
+        if new_model_name == model_name or new_model_name in running_models:
             return {
                 "code": 200,
-                "msg": "什么都不用做"
+                "msg": "无需切换"
             }

-        running_models = self.list_running_models()
         if model_name not in running_models:
             return {
                 "code": 500,
@@ -842,7 +850,7 @@ class ApiRequest:
             }

         config_models = self.list_config_models()
-        if new_model_name not in config_models:
+        if new_model_name not in config_models.get("local", []):
             return {
                 "code": 500,
                 "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"