From ab73f6ad93b746a1b18d46469e9c4dc37784c5ca Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Wed, 31 May 2023 22:26:39 +0800
Subject: [PATCH] =?UTF-8?q?=E9=80=82=E9=85=8D=E8=BF=9C=E7=A8=8BLLM?=
 =?UTF-8?q?=E8=B0=83=E7=94=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/model_config.py | 10 ++++------
 models/shared.py        |  4 ++--
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/configs/model_config.py b/configs/model_config.py
index ba925b2..44f77e8 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -70,21 +70,19 @@ llm_model_dict = {
         "provides": "LLamaLLM"
     },
     "fast-chat-chatglm-6b": {
-        "name": "FastChatOpenAI",
+        "name": "chatglm-6b",
         "pretrained_model_name": "FastChatOpenAI",
         "local_model_path": None,
         "provides": "FastChatOpenAILLM",
-        "api_base_url": "http://localhost:8000/v1",
-        "model_name": "chatglm-6b"
+        "api_base_url": "http://localhost:8000/v1"
     },
     "fast-chat-vicuna-13b-hf": {
-        "name": "FastChatOpenAI",
+        "name": "vicuna-13b-hf",
         "pretrained_model_name": "vicuna-13b-hf",
         "local_model_path": None,
         "provides": "FastChatOpenAILLM",
-        "api_base_url": "http://localhost:8000/v1",
-        "model_name": "vicuna-13b-hf"
+        "api_base_url": "http://localhost:8000/v1"
     },
 }
 
 

diff --git a/models/shared.py b/models/shared.py
index 0525750..8a76edb 100644
--- a/models/shared.py
+++ b/models/shared.py
@@ -34,7 +34,7 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
 
     loaderCheckPoint.model_path = llm_model_info["local_model_path"]
 
-    if 'FastChatOpenAILLM' in llm_model_info["local_model_path"]:
+    if 'FastChatOpenAILLM' in llm_model_info["provides"]:
         loaderCheckPoint.unload_model()
     else:
         loaderCheckPoint.reload_model()
@@ -43,5 +43,5 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
     modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
     if 'FastChatOpenAILLM' in llm_model_info["provides"]:
         modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
-        modelInsLLM.call_model_name(llm_model_info['model_name'])
+        modelInsLLM.call_model_name(llm_model_info['name'])
     return modelInsLLM