From 3dde02be28220d79e7905d8ecd277467ce808d5d Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Sat, 16 Sep 2023 07:15:08 +0800
Subject: [PATCH] Optimize the logic for fetching and switching the LLM model list (#1497)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. More accurately obtain the available models that are not running.
2. Improve the WebUI's control logic for displaying and switching the model list.
---
 configs/model_config.py.example  |  1 +
 server/api.py                    | 11 ++++++++---
 server/llm_api.py                | 11 +++++++++--
 server/model_workers/qianfan.py  |  5 ++++-
 server/utils.py                  | 23 +++++++++++++++++----
 webui_pages/dialogue/dialogue.py | 28 ++++++++++++++++----------
 webui_pages/utils.py             | 34 ++++++++++++++++++++------------
 7 files changed, 80 insertions(+), 33 deletions(-)

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index e377d9d..dff60f7 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -40,6 +40,7 @@ MODEL_PATH = {
         "chatglm2-6b": "THUDM/chatglm2-6b",
         "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
         "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
+        "baichuan-7b": "baichuan-inc/Baichuan-7B",
     },
 }

diff --git a/server/api.py b/server/api.py
index 357a067..91326ff 100644
--- a/server/api.py
+++ b/server/api.py
@@ -17,7 +17,7 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                               update_docs, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
-from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
+from server.llm_api import list_running_models,list_config_models, change_llm_model, stop_llm_model
 from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
 from typing import List

@@ -125,10 +125,15 @@ def create_app():
                  )(recreate_vector_store)

     # LLM模型相关接口
-    app.post("/llm_model/list_models",
+    app.post("/llm_model/list_running_models",
              tags=["LLM Model Management"],
              summary="列出当前已加载的模型",
-             )(list_llm_models)
+             )(list_running_models)
+
+    app.post("/llm_model/list_config_models",
+             tags=["LLM Model Management"],
+             summary="列出configs已配置的模型",
+             )(list_config_models)

     app.post("/llm_model/stop",
              tags=["LLM Model Management"],
diff --git a/server/llm_api.py b/server/llm_api.py
index 5843e89..a9e5ab6 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -1,10 +1,10 @@
 from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
-from server.utils import BaseResponse, fschat_controller_address
+from server.utils import BaseResponse, fschat_controller_address, list_llm_models
 import httpx


-def list_llm_models(
+def list_running_models(
     controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
     placeholder: str = Body(None, description="该参数未使用,占位用"),
 ) -> BaseResponse:
@@ -24,6 +24,13 @@
             msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")

+def list_config_models() -> BaseResponse:
+    '''
+    从本地获取configs中配置的模型列表
+    '''
+    return BaseResponse(data=list_llm_models())
+
+
 def stop_llm_model(
     model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
     controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
 ) -> BaseResponse:
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index 8a593a7..f7a5161 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -72,7 +72,10 @@ def request_qianfan_api(
     version_url = config.get("version_url")
     access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
     if not access_token:
-        raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
+        yield {
+            "error_code": 403,
+            "error_msg": f"failed to get access token. have you set the correct api_key and secret key?",
+        }

     url = BASE_URL.format(
         model_version=version_url or MODEL_VERSIONS[version],
diff --git a/server/utils.py b/server/utils.py
index ec51ae0..4dc29ae 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -5,12 +5,12 @@ from fastapi import FastAPI
 from pathlib import Path
 import asyncio
 from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
-                     MODEL_PATH, MODEL_ROOT_PATH,
+                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
                      logger, log_verbose,
                      FSCHAT_MODEL_WORKERS)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Literal, Optional, Callable, Generator, Dict, Any
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Tuple


 thread_pool = ThreadPoolExecutor(os.cpu_count())
@@ -201,10 +201,25 @@ def MakeFastAPIOffline(

 # 从model_config中获取模型信息
 def list_embed_models() -> List[str]:
+    '''
+    get names of configured embedding models
+    '''
     return list(MODEL_PATH["embed_model"])

-def list_llm_models() -> List[str]:
-    return list(MODEL_PATH["llm_model"])
+def list_llm_models() -> Dict[str, List[str]]:
+    '''
+    get names of configured llm models with different types.
+    return [(model_name, config_type), ...]
+    '''
+    workers = list(FSCHAT_MODEL_WORKERS)
+    if "default" in workers:
+        workers.remove("default")
+    return {
+        "local": list(MODEL_PATH["llm_model"]),
+        "online": list(ONLINE_LLM_MODEL),
+        "worker": workers,
+    }
+

 def get_model_path(model_name: str, type: str = None) -> Optional[str]:
     if type in MODEL_PATH:
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 50ad423..0d6a2ad 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -68,6 +68,7 @@ def dialogue_page(api: ApiRequest):
             config = get_model_worker_config(llm_model)
             if not config.get("online_api"):  # 只有本地model_worker可以切换模型
                 st.session_state["prev_llm_model"] = llm_model
+            st.session_state["cur_llm_model"] = st.session_state.llm_model

         def llm_model_format_func(x):
             if x in running_models:
@@ -75,25 +76,32 @@
                 return x

         running_models = api.list_running_models()
+        available_models = []
         config_models = api.list_config_models()
-        for x in running_models:
-            if x in config_models:
-                config_models.remove(x)
-        llm_models = running_models + config_models
-        cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
-        index = llm_models.index(cur_model)
+        for models in config_models.values():
+            for m in models:
+                if m not in running_models:
+                    available_models.append(m)
+        llm_models = running_models + available_models
+        index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL))
         llm_model = st.selectbox("选择LLM模型:",
                                  llm_models,
                                  index,
                                  format_func=llm_model_format_func,
                                  on_change=on_llm_change,
-                                 # key="llm_model",
+                                 key="llm_model",
                                  )
         if (st.session_state.get("prev_llm_model") != llm_model
-                and not get_model_worker_config(llm_model).get("online_api")):
+                and not get_model_worker_config(llm_model).get("online_api")
+                and llm_model not in running_models):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
-                r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
-                st.session_state["cur_llm_model"] = llm_model
+                prev_model = st.session_state.get("prev_llm_model")
+                r = api.change_llm_model(prev_model, llm_model)
+                if msg := check_error_msg(r):
+                    st.error(msg)
+                elif msg := check_success_msg(r):
+                    st.success(msg)
+                st.session_state["prev_llm_model"] = llm_model

         temperature = st.slider("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05)
         history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index c23a8ea..5708efc 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -766,23 +766,31 @@
             "controller_address": controller_address,
         }
         if no_remote_api:
-            from server.llm_api import list_llm_models
-            return list_llm_models(**data).data
+            from server.llm_api import list_running_models
+            return list_running_models(**data).data
         else:
             r = self.post(
-                "/llm_model/list_models",
+                "/llm_model/list_running_models",
                 json=data,
             )
             return r.json().get("data", [])

-    def list_config_models(self):
+    def list_config_models(self, no_remote_api: bool = None) -> Dict[str, List[str]]:
         '''
-        获取configs中配置的模型列表
+        获取configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。
+        如果no_remote_api=True, 从运行ApiRequest的机器上获取;否则从运行api.py的机器上获取。
         '''
-        models = list(FSCHAT_MODEL_WORKERS.keys())
-        if "default" in models:
-            models.remove("default")
-        return models
+        if no_remote_api is None:
+            no_remote_api = self.no_remote_api
+
+        if no_remote_api:
+            from server.llm_api import list_config_models
+            return list_config_models().data
+        else:
+            r = self.post(
+                "/llm_model/list_config_models",
+            )
+            return r.json().get("data", {})

     def stop_llm_model(
         self,
@@ -828,13 +836,13 @@
         if not model_name or not new_model_name:
             return

-        if new_model_name == model_name:
+        running_models = self.list_running_models()
+        if new_model_name == model_name or new_model_name in running_models:
             return {
                 "code": 200,
-                "msg": "什么都不用做"
+                "msg": "无需切换"
             }

-        running_models = self.list_running_models()
         if model_name not in running_models:
             return {
                 "code": 500,
@@ -842,7 +850,7 @@
         }

         config_models = self.list_config_models()
-        if new_model_name not in config_models:
+        if new_model_name not in config_models.get("local", []):
             return {
                 "code": 500,
                 "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
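
Taken together, the patch exposes two complementary endpoints: /llm_model/list_running_models returns the models currently loaded by the fastchat workers, and /llm_model/list_config_models returns everything configured in configs, grouped as {"local": [...], "online": [...], "worker": [...]}. The WebUI merges the two to build its selectbox: running models first, then configured models that are not running yet. Below is a minimal client-side sketch of that same merge, for illustration only and not part of the patch; API_BASE is an assumed default address for api.py and may differ in your deployment, and httpx (already used by the server code above) is assumed to be installed.

import httpx

API_BASE = "http://127.0.0.1:7861"  # assumed api.py address; adjust to your deployment

def selectable_models() -> list:
    # models currently loaded by the fastchat model workers
    running = httpx.post(f"{API_BASE}/llm_model/list_running_models",
                         json={}).json().get("data", [])
    # configured models, grouped as {"local": [...], "online": [...], "worker": [...]}
    configured = httpx.post(f"{API_BASE}/llm_model/list_config_models",
                            json={}).json().get("data", {})
    # running models first, then configured models that are not running yet
    not_running = [m for group in configured.values()
                   for m in group if m not in running]
    return running + not_running

if __name__ == "__main__":
    print(selectable_models())

This mirrors the loop added to webui_pages/dialogue/dialogue.py, so the printed list should match what the WebUI offers in its model selectbox.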