update loader.py
This commit is contained in:
parent
5524c47645
commit
00d80335fe
|
|
@ -1,4 +1,4 @@
|
|||
|
||||
from .chatglm_llm import ChatGLM
|
||||
from .llama_llm import LLamaLLM
|
||||
# from .llama_llm import LLamaLLM
|
||||
from .moss_llm import MOSSLLM
|
||||
|
|
|
|||
|
|
@ -36,10 +36,6 @@ class LoaderCheckPoint:
|
|||
lora_dir: str = None
|
||||
ptuning_dir: str = None
|
||||
use_ptuning_v2: bool = False
|
||||
cpu: bool = False
|
||||
gpu_memory: object = None
|
||||
cpu_memory: object = None
|
||||
auto_devices: object = True
|
||||
# 如果开启了8bit量化加载,项目无法启动,参考此位置,选择合适的cuda版本,https://github.com/TimDettmers/bitsandbytes/issues/156
|
||||
load_in_8bit: bool = False
|
||||
is_llamacpp: bool = False
|
||||
|
|
@ -56,20 +52,16 @@ class LoaderCheckPoint:
|
|||
:param params:
|
||||
"""
|
||||
self.model_path = None
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.params = params or {}
|
||||
self.no_remote_model = params.get('no_remote_model', False)
|
||||
self.model_name = params.get('model', '')
|
||||
self.lora = params.get('lora', '')
|
||||
self.use_ptuning_v2 = params.get('use_ptuning_v2', False)
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.model_dir = params.get('model_dir', '')
|
||||
self.lora_dir = params.get('lora_dir', '')
|
||||
self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2')
|
||||
self.cpu = params.get('cpu', False)
|
||||
self.gpu_memory = params.get('gpu_memory', None)
|
||||
self.cpu_memory = params.get('cpu_memory', None)
|
||||
self.auto_devices = params.get('auto_devices', True)
|
||||
self.load_in_8bit = params.get('load_in_8bit', False)
|
||||
self.bf16 = params.get('bf16', False)
|
||||
|
||||
|
|
@ -111,8 +103,8 @@ class LoaderCheckPoint:
|
|||
LoaderClass = AutoModelForCausalLM
|
||||
|
||||
# Load the model in simple 16-bit mode by default
|
||||
if not any([self.cpu, self.load_in_8bit, self.auto_devices, self.gpu_memory is not None,
|
||||
self.cpu_memory is not None, self.is_llamacpp]):
|
||||
if not any([self.llm_device.lower()=="cpu",
|
||||
self.load_in_8bit, self.is_llamacpp]):
|
||||
|
||||
if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"):
|
||||
# 根据当前设备GPU数量决定是否进行多卡部署
|
||||
|
|
@ -140,14 +132,15 @@ class LoaderCheckPoint:
|
|||
if 'chatglm' in model_name.lower():
|
||||
device_map = self.chatglm_auto_configure_device_map(num_gpus)
|
||||
elif 'moss' in model_name.lower():
|
||||
device_map = self.moss_auto_configure_device_map(num_gpus,model_name)
|
||||
device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
|
||||
else:
|
||||
device_map = self.chatglm_auto_configure_device_map(num_gpus)
|
||||
|
||||
model = dispatch_model(model, device_map=device_map)
|
||||
else:
|
||||
print(
|
||||
"Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
|
||||
"Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "
|
||||
"detected.\nFalling back to CPU mode.\n")
|
||||
model = (
|
||||
AutoModel.from_pretrained(
|
||||
checkpoint,
|
||||
|
|
@ -169,47 +162,20 @@ class LoaderCheckPoint:
|
|||
# Custom
|
||||
else:
|
||||
params = {"low_cpu_mem_usage": True}
|
||||
if not any((self.cpu, torch.cuda.is_available(), torch.has_mps)):
|
||||
print(
|
||||
"Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
|
||||
self.cpu = True
|
||||
|
||||
if self.cpu:
|
||||
params["torch_dtype"] = torch.float32
|
||||
if not self.llm_device.lower().startswith("cuda"):
|
||||
raise SystemError("8bit 模型需要 CUDA 支持,或者改用量化后模型!")
|
||||
else:
|
||||
params["device_map"] = 'auto'
|
||||
params["trust_remote_code"] = True
|
||||
if self.load_in_8bit and any((self.auto_devices, self.gpu_memory)):
|
||||
if self.load_in_8bit:
|
||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=True)
|
||||
elif self.load_in_8bit:
|
||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
llm_int8_enable_fp32_cpu_offload=False)
|
||||
elif self.bf16:
|
||||
params["torch_dtype"] = torch.bfloat16
|
||||
else:
|
||||
params["torch_dtype"] = torch.float16
|
||||
|
||||
if self.gpu_memory:
|
||||
memory_map = list(map(lambda x: x.strip(), self.gpu_memory))
|
||||
max_cpu_memory = self.cpu_memory.strip() if self.cpu_memory is not None else '99GiB'
|
||||
max_memory = {}
|
||||
for i in range(len(memory_map)):
|
||||
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else \
|
||||
memory_map[i]
|
||||
max_memory['cpu'] = max_cpu_memory
|
||||
params['max_memory'] = max_memory
|
||||
elif self.auto_devices:
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
|
||||
suggestion = round((total_mem - 1000) / 1000) * 1000
|
||||
if total_mem - suggestion < 800:
|
||||
suggestion -= 1000
|
||||
suggestion = int(round(suggestion / 1000))
|
||||
print(
|
||||
f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
||||
|
||||
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{self.cpu_memory or 99}GiB'}
|
||||
params['max_memory'] = max_memory
|
||||
|
||||
if self.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
|
||||
config = AutoConfig.from_pretrained(checkpoint)
|
||||
with init_empty_weights():
|
||||
|
|
@ -236,7 +202,8 @@ class LoaderCheckPoint:
|
|||
tokenizer.eos_token_id = 2
|
||||
tokenizer.bos_token_id = 1
|
||||
tokenizer.pad_token_id = 0
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
|
||||
|
|
@ -331,7 +298,7 @@ class LoaderCheckPoint:
|
|||
if len(lora_names) > 0:
|
||||
print("Applying the following LoRAs to {}: {}".format(self.model_name, ', '.join(lora_names)))
|
||||
params = {}
|
||||
if not self.cpu:
|
||||
if self.llm_device.lower() != "cpu":
|
||||
params['dtype'] = self.model.dtype
|
||||
if hasattr(self.model, "hf_device_map"):
|
||||
params['device_map'] = {"base_model.model." + k: v for k, v in self.model.hf_device_map.items()}
|
||||
|
|
@ -344,7 +311,7 @@ class LoaderCheckPoint:
|
|||
for lora in lora_names[1:]:
|
||||
self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora)
|
||||
|
||||
if not self.load_in_8bit and not self.cpu:
|
||||
if not self.load_in_8bit and self.llm_device.lower() != "cpu":
|
||||
|
||||
if not hasattr(self.model, "hf_device_map"):
|
||||
if torch.has_mps:
|
||||
|
|
@ -355,12 +322,23 @@ class LoaderCheckPoint:
|
|||
|
||||
def clear_torch_cache(self):
|
||||
gc.collect()
|
||||
if not self.cpu:
|
||||
device_id = "0" if torch.cuda.is_available() else None
|
||||
CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
|
||||
with torch.cuda.device(CUDA_DEVICE):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
if self.llm_device.lower() != "cpu":
|
||||
if torch.has_mps:
|
||||
try:
|
||||
from torch.mps import empty_cache
|
||||
empty_cache()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(
|
||||
"如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
|
||||
elif torch.has_cuda:
|
||||
device_id = "0" if torch.cuda.is_available() else None
|
||||
CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
|
||||
with torch.cuda.device(CUDA_DEVICE):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
else:
|
||||
print("未检测到 cuda 或 mps,暂不支持清理显存")
|
||||
|
||||
def unload_model(self):
|
||||
del self.model
|
||||
|
|
@ -382,7 +360,7 @@ class LoaderCheckPoint:
|
|||
prefix_encoder_file.close()
|
||||
self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
|
||||
self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print("加载PrefixEncoder config.json失败")
|
||||
|
||||
self.model, self.tokenizer = self._load_model(self.model_name)
|
||||
|
|
@ -399,7 +377,7 @@ class LoaderCheckPoint:
|
|||
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
||||
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
||||
self.model.transformer.prefix_encoder.float()
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print("加载PrefixEncoder模型参数失败")
|
||||
|
||||
self.model = self.model.eval()
|
||||
|
|
|
|||
Loading…
Reference in New Issue