diff --git a/models/loader/loader.py b/models/loader/loader.py
index 559031b..cc4457c 100644
--- a/models/loader/loader.py
+++ b/models/loader/loader.py
@@ -130,13 +130,13 @@ class LoaderCheckPoint:
                     # 可传入device_map自定义每张卡的部署情况
                     if self.device_map is None:
                         if 'chatglm' in model_name.lower():
-                            device_map = self.chatglm_auto_configure_device_map(num_gpus)
+                            self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
                         elif 'moss' in model_name.lower():
-                            device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
+                            self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
                         else:
-                            device_map = self.chatglm_auto_configure_device_map(num_gpus)
+                            self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
 
-                    model = dispatch_model(model, device_map=device_map)
+                    model = dispatch_model(model, device_map=self.device_map)
             else:
                 print(
                     "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "