diff --git a/configs/model_config.py b/configs/model_config.py
index ba925b2..44f77e8 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -70,21 +70,19 @@ llm_model_dict = {
         "provides": "LLamaLLM"
     },
     "fast-chat-chatglm-6b": {
-        "name": "FastChatOpenAI",
+        "name": "chatglm-6b",
         "pretrained_model_name": "FastChatOpenAI",
         "local_model_path": None,
         "provides": "FastChatOpenAILLM",
-        "api_base_url": "http://localhost:8000/v1",
-        "model_name": "chatglm-6b"
+        "api_base_url": "http://localhost:8000/v1"
     },
     "fast-chat-vicuna-13b-hf": {
-        "name": "FastChatOpenAI",
+        "name": "vicuna-13b-hf",
         "pretrained_model_name": "vicuna-13b-hf",
         "local_model_path": None,
         "provides": "FastChatOpenAILLM",
-        "api_base_url": "http://localhost:8000/v1",
-        "model_name": "vicuna-13b-hf"
+        "api_base_url": "http://localhost:8000/v1"
     },
 }
diff --git a/models/shared.py b/models/shared.py
index 0525750..8a76edb 100644
--- a/models/shared.py
+++ b/models/shared.py
@@ -34,7 +34,7 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
     loaderCheckPoint.model_path = llm_model_info["local_model_path"]
 
-    if 'FastChatOpenAILLM' in llm_model_info["local_model_path"]:
+    if 'FastChatOpenAILLM' in llm_model_info["provides"]:
         loaderCheckPoint.unload_model()
     else:
         loaderCheckPoint.reload_model()
 
@@ -43,5 +43,5 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
     modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
     if 'FastChatOpenAILLM' in llm_model_info["provides"]:
         modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
-        modelInsLLM.call_model_name(llm_model_info['model_name'])
+        modelInsLLM.call_model_name(llm_model_info['name'])
     return modelInsLLM