diff --git a/models/loader/loader.py b/models/loader/loader.py index f315e6c..9fe5ce9 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -140,7 +140,17 @@ class LoaderCheckPoint: elif 'moss' in model_name.lower(): self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name) else: - self.device_map = self.chatglm_auto_configure_device_map(num_gpus) + # 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败 + # 在chatglm2-6b,bloom-3b,bloomz-7b1上进行了测试,GPU负载也相对均衡 + from accelerate.utils import get_balanced_memory + max_memory = get_balanced_memory(model, + dtype=torch.int8 if self.load_in_8bit else None, + low_zero=False, + no_split_module_classes=model._no_split_modules) + self.device_map = infer_auto_device_map(model, + dtype=torch.float16 if not self.load_in_8bit else torch.int8, + max_memory=max_memory, + no_split_module_classes=model._no_split_modules) model = dispatch_model(model, device_map=self.device_map) else: