From 3324c12d69b28460905e0e5d35921e6f3289bfcd Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Fri, 26 May 2023 22:52:55 +0800
Subject: [PATCH] Add CPU model loading logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 models/loader/loader.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/models/loader/loader.py b/models/loader/loader.py
index 02cf24e..47c365b 100644
--- a/models/loader/loader.py
+++ b/models/loader/loader.py
@@ -130,11 +130,8 @@ class LoaderCheckPoint:
 
                 model = dispatch_model(model, device_map=self.device_map)
             else:
-                # print(
-                #     "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "
-                #     "detected.\nFalling back to CPU mode.\n")
                 model = (
-                    AutoModel.from_pretrained(
+                    LoaderClass.from_pretrained(
                         checkpoint,
                         config=self.model_config,
                         trust_remote_code=True)
@@ -202,7 +199,11 @@ class LoaderCheckPoint:
                 ) from exc
         # Custom
         else:
-            pass
+
+            print(
+                "Warning: self.llm_device is not a CUDA device.\nNo GPU will be used; falling back to CPU mode.\n")
+            params = {"low_cpu_mem_usage": True, "torch_dtype": torch.float32, "trust_remote_code": True}
+            model = LoaderClass.from_pretrained(checkpoint, **params).to(self.llm_device)
 
         # Loading the tokenizer
         if type(model) is transformers.LlamaForCausalLM:
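
The CPU fallback this patch adds can be exercised in isolation. Below is a minimal sketch, assuming LoaderClass resolves to transformers.AutoModel (as in the branch the first hunk rewrites) and that checkpoint names a Hugging Face model repo; the model id used here is illustrative and not part of the patch.

# Minimal sketch of the patch's CPU loading path, assuming
# LoaderClass == transformers.AutoModel; the checkpoint name
# below is hypothetical.
import torch
from transformers import AutoModel

checkpoint = "THUDM/chatglm-6b"  # illustrative model id

# Same keyword arguments as the patch: stream weights during load to
# keep peak RAM low, force float32 (CPU kernels generally lack fp16
# support), and allow the repo's custom modeling code to run.
params = {
    "low_cpu_mem_usage": True,
    "torch_dtype": torch.float32,
    "trust_remote_code": True,
}

model = AutoModel.from_pretrained(checkpoint, **params).to("cpu")
model.eval()

Note that the original added line ended with .to(self.llm_device, dtype=float); Python's builtin float maps to torch.float64, which would silently double memory use and contradict the torch_dtype=torch.float32 already passed in params, so the corrected patch above drops the redundant dtype argument.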