From 2adfa4277c137eaaedeb6adafe3e98cea85dc445 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Tue, 14 Nov 2023 21:17:32 +0800
Subject: [PATCH] Include in the webui model list: online models not started
 by a model worker (such as openai-api); local models that are already
 downloaded (#2060)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/llm_api.py                | 25 +++++++++++--------------
 server/utils.py                  | 14 +++++++++-----
 webui_pages/dialogue/dialogue.py | 16 ++++++++++------
 webui_pages/utils.py             | 11 +++++++++--
 4 files changed, 39 insertions(+), 27 deletions(-)

diff --git a/server/llm_api.py b/server/llm_api.py
index 015a1c0..fbac493 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -2,7 +2,7 @@ from fastapi import Body
 from configs import logger, log_verbose, LLM_MODELS, HTTPX_DEFAULT_TIMEOUT
 from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                           get_httpx_client, get_model_worker_config)
-from copy import deepcopy
+from typing import List
 
 
 def list_running_models(
@@ -28,26 +28,23 @@ def list_running_models(
             msg=f"failed to get available models from controller: {controller_address}. Error message: {e}")
 
 
-def list_config_models() -> BaseResponse:
+def list_config_models(
+    types: List[str] = Body(["local", "online"], description="Model config categories, e.g. local, online, worker"),
+    placeholder: str = Body(None, description="Placeholder, no actual effect")
+) -> BaseResponse:
     '''
     Get the list of models configured in local configs
     '''
-    configs = {}
-    # remove sensitive info from the ONLINE_MODEL configs
-    for name, config in list_config_llm_models()["online"].items():
-        configs[name] = {}
-        for k, v in config.items():
-            if not (k == "worker_class"
-                    or "key" in k.lower()
-                    or "secret" in k.lower()
-                    or k.lower().endswith("id")):
-                configs[name][k] = v
-    return BaseResponse(data=configs)
+    data = {}
+    for type, models in list_config_llm_models().items():
+        if type in types:
+            data[type] = {m: get_model_config(m).data for m in models}
+    return BaseResponse(data=data)
 
 
 def get_model_config(
     model_name: str = Body(description="Name of the LLM model in the configs"),
-    placeholder: str = Body(description="Placeholder, no actual effect")
+    placeholder: str = Body(None, description="Placeholder, no actual effect")
 ) -> BaseResponse:
     '''
     Get the (merged) config of an LLM model
diff --git a/server/utils.py b/server/utils.py
index 27187b5..ab307fa 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -342,13 +342,14 @@ def list_embed_models() -> List[str]:
 def list_config_llm_models() -> Dict[str, Dict]:
     '''
     get configured llm models with different types.
-    return [(model_name, config_type), ...]
+    return {config_type: {model_name: config}, ...}
     '''
-    workers = list(FSCHAT_MODEL_WORKERS)
+    workers = FSCHAT_MODEL_WORKERS.copy()
+    workers.pop("default", None)
 
     return {
-        "local": MODEL_PATH["llm_model"],
-        "online": ONLINE_LLM_MODEL,
+        "local": MODEL_PATH["llm_model"].copy(),
+        "online": ONLINE_LLM_MODEL.copy(),
         "worker": workers,
     }
 
@@ -406,7 +407,10 @@ def get_model_worker_config(model_name: str = None) -> dict:
                          exc_info=e if log_verbose else None)
     # local model
     if model_name in MODEL_PATH["llm_model"]:
-        config["model_path"] = get_model_path(model_name)
+        path = get_model_path(model_name)
+        config["model_path"] = path
+        if path and os.path.isdir(path):
+            config["model_path_exists"] = True
         config["device"] = llm_device(config.get("device"))
     return config
 
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 70448b0..7286cae 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -37,8 +37,8 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
 
 
 def dialogue_page(api: ApiRequest, is_lite: bool = False):
+    default_model = api.get_default_llm_model()[0]
     if not chat_box.chat_inited:
-        default_model = api.get_default_llm_model()[0]
         st.toast(
             f"Welcome to [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat)! \n\n"
             f"The currently running model is `{default_model}`. You can start asking questions."
@@ -83,15 +83,19 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         running_models = list(api.list_running_models())
         available_models = []
         config_models = api.list_config_models()
-        worker_models = list(config_models.get("worker", {}))  # only list models configured in FSCHAT_MODEL_WORKERS
-        for m in worker_models:
-            if m not in running_models and m != "default":
-                available_models.append(m)
+        for k, v in config_models.get("local", {}).items():  # list models with a valid local path configured
+            if (v.get("model_path_exists")
+                and k not in running_models):
+                available_models.append(k)
         for k, v in config_models.get("online", {}).items():  # list directly accessible models from ONLINE_MODELS
             if not v.get("provider") and k not in running_models:
                 available_models.append(k)
         llm_models = running_models + available_models
-        index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
+        cur_llm_model = st.session_state.get("cur_llm_model", default_model)
+        if cur_llm_model in llm_models:
+            index = llm_models.index(cur_llm_model)
+        else:
+            index = 0
         llm_model = st.selectbox("Select LLM model:",
                                  llm_models,
                                  index,
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index b4a520e..bbae64e 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -730,12 +730,19 @@ class ApiRequest:
         else:
             return ret_sync()
 
-    def list_config_models(self) -> Dict[str, List[str]]:
+    def list_config_models(
+        self,
+        types: List[str] = ["local", "online"],
+    ) -> Dict[str, Dict]:
         '''
-        Get the models configured on the server, returned as {"type": [model_name1, model_name2, ...], ...}.
+        Get the models configured on the server, returned as {"type": {model_name: config}, ...}.
         '''
+        data = {
+            "types": types,
+        }
         response = self.post(
             "/llm_model/list_config_models",
+            json=data,
         )
         return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
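
Note: after this patch, /llm_model/list_config_models accepts a `types` request body and returns configs keyed by category rather than a flat online-only mapping. Below is a minimal sketch (not part of the patch) of how a client might exercise the reworked endpoint directly; the base URL is an assumption for illustration and should be adjusted to your deployment.

# Minimal sketch: querying the reworked endpoint with the new `types` body.
# BASE_URL is an assumption for illustration, not defined by this patch.
import requests

BASE_URL = "http://127.0.0.1:7861"

resp = requests.post(
    f"{BASE_URL}/llm_model/list_config_models",
    json={"types": ["local", "online"]},  # request body introduced above
)
data = resp.json().get("data", {})

# The response shape is now {config_type: {model_name: config}, ...}.
# Local models whose path exists on disk carry "model_path_exists": True,
# mirroring the check added to get_model_worker_config().
for name, cfg in data.get("local", {}).items():
    if cfg.get("model_path_exists"):
        print("downloaded local model:", name)

# Online models without a "provider" key are the ones the webui now lists
# as directly accessible (e.g. openai-api).
for name, cfg in data.get("online", {}).items():
    if not cfg.get("provider"):
        print("directly accessible online model:", name)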