From 62ce5f07759db201aba78127d4d4d3a6b0ca52e4 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Fri, 19 May 2023 23:38:36 +0800 Subject: [PATCH] =?UTF-8?q?self.device=5Fmap=20=E5=8F=82=E6=95=B0=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E9=80=BB=E8=BE=91=20LLamaLLM=20=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/__init__.py | 2 +- models/loader/loader.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/models/__init__.py b/models/__init__.py index 4d2d683..2a58a8f 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,4 +1,4 @@ from .chatglm_llm import ChatGLM -# from .llama_llm import LLamaLLM +from .llama_llm import LLamaLLM from .moss_llm import MOSSLLM diff --git a/models/loader/loader.py b/models/loader/loader.py index 559031b..cc4457c 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -130,13 +130,13 @@ class LoaderCheckPoint: # 可传入device_map自定义每张卡的部署情况 if self.device_map is None: if 'chatglm' in model_name.lower(): - device_map = self.chatglm_auto_configure_device_map(num_gpus) + self.device_map = self.chatglm_auto_configure_device_map(num_gpus) elif 'moss' in model_name.lower(): - device_map = self.moss_auto_configure_device_map(num_gpus, model_name) + self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name) else: - device_map = self.chatglm_auto_configure_device_map(num_gpus) + self.device_map = self.chatglm_auto_configure_device_map(num_gpus) - model = dispatch_model(model, device_map=device_map) + model = dispatch_model(model, device_map=self.device_map) else: print( "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "