update loader.py

This commit is contained in:
imClumsyPanda 2023-05-20 00:06:41 +08:00
parent aa26645407
commit d5ffdaa281
1 changed files with 2 additions and 1 deletions

View File

@ -16,6 +16,7 @@ from transformers.modeling_utils import no_init_weights
from transformers.utils import ContextManagers
from accelerate import init_empty_weights
from accelerate.utils import get_balanced_memory, infer_auto_device_map
from configs.model_config import LLM_DEVICE
class LoaderCheckPoint:
@ -44,7 +45,7 @@ class LoaderCheckPoint:
# 自定义设备网络
device_map: Optional[Dict[str, int]] = None
# 默认 cuda 如果不支持cuda使用多卡 如果不支持多卡 使用cpu
llm_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
llm_device = LLM_DEVICE
def __init__(self, params: dict = None):
"""