update loader.py
This commit is contained in:
parent
aa26645407
commit
d5ffdaa281
|
|
@ -16,6 +16,7 @@ from transformers.modeling_utils import no_init_weights
|
||||||
from transformers.utils import ContextManagers
|
from transformers.utils import ContextManagers
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from accelerate.utils import get_balanced_memory, infer_auto_device_map
|
from accelerate.utils import get_balanced_memory, infer_auto_device_map
|
||||||
|
from configs.model_config import LLM_DEVICE
|
||||||
|
|
||||||
|
|
||||||
class LoaderCheckPoint:
|
class LoaderCheckPoint:
|
||||||
|
|
@ -44,7 +45,7 @@ class LoaderCheckPoint:
|
||||||
# 自定义设备网络
|
# 自定义设备网络
|
||||||
device_map: Optional[Dict[str, int]] = None
|
device_map: Optional[Dict[str, int]] = None
|
||||||
# 默认 cuda ,如果不支持cuda使用多卡, 如果不支持多卡 使用cpu
|
# 默认 cuda ,如果不支持cuda使用多卡, 如果不支持多卡 使用cpu
|
||||||
llm_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
llm_device = LLM_DEVICE
|
||||||
|
|
||||||
def __init__(self, params: dict = None):
|
def __init__(self, params: dict = None):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue