self.device_map 参数初始化逻辑

LLamaLLM 加载器
This commit is contained in:
glide-the 2023-05-19 23:38:36 +08:00
parent 49e47231af
commit 62ce5f0775
2 changed files with 5 additions and 5 deletions

View File

@ -1,4 +1,4 @@
from .chatglm_llm import ChatGLM from .chatglm_llm import ChatGLM
# from .llama_llm import LLamaLLM from .llama_llm import LLamaLLM
from .moss_llm import MOSSLLM from .moss_llm import MOSSLLM

View File

@ -130,13 +130,13 @@ class LoaderCheckPoint:
# 可传入device_map自定义每张卡的部署情况 # 可传入device_map自定义每张卡的部署情况
if self.device_map is None: if self.device_map is None:
if 'chatglm' in model_name.lower(): if 'chatglm' in model_name.lower():
device_map = self.chatglm_auto_configure_device_map(num_gpus) self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
elif 'moss' in model_name.lower(): elif 'moss' in model_name.lower():
device_map = self.moss_auto_configure_device_map(num_gpus, model_name) self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
else: else:
device_map = self.chatglm_auto_configure_device_map(num_gpus) self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
model = dispatch_model(model, device_map=device_map) model = dispatch_model(model, device_map=self.device_map)
else: else:
print( print(
"Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been " "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "