diff --git a/models/loader/loader.py b/models/loader/loader.py index f315e6c..0c32835 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -257,10 +257,21 @@ class LoaderCheckPoint: # 在调用chat或者stream_chat时,input_ids会被放到model.device上 # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上 - device_map = {f'{layer_prefix}.word_embeddings': 0, + + encode = "" + if 'chatglm2' in self.model_name: + device_map = { + f"{layer_prefix}.embedding.word_embeddings": 0, + f"{layer_prefix}.rotary_pos_emb": 0, + f"{layer_prefix}.output_layer": 0, + f"{layer_prefix}.encoder.final_layernorm": 0, + f"base_model.model.output_layer": 0 + } + encode = ".encoder" + else: + device_map = {f'{layer_prefix}.word_embeddings': 0, f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0, f'base_model.model.lm_head': 0, } - used = 2 gpu_target = 0 for i in range(num_trans_layers): @@ -268,7 +279,7 @@ class LoaderCheckPoint: gpu_target += 1 used = 0 assert gpu_target < num_gpus - device_map[f'{layer_prefix}.layers.{i}'] = gpu_target + device_map[f'{layer_prefix}{encode}.layers.{i}'] = gpu_target used += 1 return device_map