From d5ffdaa281933c153eef50a38b47c223176ddc8a Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 20 May 2023 00:06:41 +0800 Subject: [PATCH] update loader.py --- models/loader/loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/loader/loader.py b/models/loader/loader.py index cc4457c..b6608f3 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -16,6 +16,7 @@ from transformers.modeling_utils import no_init_weights from transformers.utils import ContextManagers from accelerate import init_empty_weights from accelerate.utils import get_balanced_memory, infer_auto_device_map +from configs.model_config import LLM_DEVICE class LoaderCheckPoint: @@ -44,7 +45,7 @@ class LoaderCheckPoint: # 自定义设备网络 device_map: Optional[Dict[str, int]] = None # 默认 cuda ,如果不支持cuda使用多卡, 如果不支持多卡 使用cpu - llm_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + llm_device = LLM_DEVICE def __init__(self, params: dict = None): """