From 300d287d61b315802d7d9fd930cc438cb973101a Mon Sep 17 00:00:00 2001 From: hzg0601 Date: Thu, 6 Jul 2023 04:47:44 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E9=BB=98=E8=AE=A4=E7=9A=84?= =?UTF-8?q?=E5=A4=9A=E5=8D=A1=E9=83=A8=E7=BD=B2=E6=96=B9=E6=A1=88=EF=BC=8C?= =?UTF-8?q?=E5=9F=BA=E6=9C=AC=E4=BF=9D=E8=AF=81=E9=92=88=E5=AF=B9=E6=96=B0?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E4=B9=9F=E4=B8=8D=E4=BC=9A=E5=A4=B1=E8=B4=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/loader/loader.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/models/loader/loader.py b/models/loader/loader.py index f315e6c..9fe5ce9 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -140,7 +140,17 @@ class LoaderCheckPoint: elif 'moss' in model_name.lower(): self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name) else: - self.device_map = self.chatglm_auto_configure_device_map(num_gpus) + # 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败 + # 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试,GPU负载也相对均衡 + from accelerate.utils import get_balanced_memory + max_memory = get_balanced_memory(model, + dtype=torch.int8 if self.load_in_8bit else None, + low_zero=False, + no_split_module_classes=model._no_split_modules) + self.device_map = infer_auto_device_map(model, + dtype=torch.float16 if not self.load_in_8bit else torch.int8, + max_memory=max_memory, + no_split_module_classes=model._no_split_modules) model = dispatch_model(model, device_map=self.device_map) else: