From 421ce3da3a935fbdeb595cfeb4b62415d6059858 Mon Sep 17 00:00:00 2001 From: Jingsong-Yan <75230787+Jingsong-Yan@users.noreply.github.com> Date: Fri, 30 Jun 2023 21:08:38 +0800 Subject: [PATCH] Add device_map config to support chatglm2-6b (#734) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit chatglm-6b和chatglm2-6b的参数命名不一致,本次提交旨在解决chatglm2-6b device_map 创建的问题。在chatglm_auto_configure_device_map 函数中新增了chatglm2-6b device_map 创建的相关代码。 --- models/loader/loader.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/models/loader/loader.py b/models/loader/loader.py index f315e6c..0c32835 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -257,10 +257,21 @@ class LoaderCheckPoint: # 在调用chat或者stream_chat时,input_ids会被放到model.device上 # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上 - device_map = {f'{layer_prefix}.word_embeddings': 0, + + encode = "" + if 'chatglm2' in self.model_name: + device_map = { + f"{layer_prefix}.embedding.word_embeddings": 0, + f"{layer_prefix}.rotary_pos_emb": 0, + f"{layer_prefix}.output_layer": 0, + f"{layer_prefix}.encoder.final_layernorm": 0, + f"base_model.model.output_layer": 0 + } + encode = ".encoder" + else: + device_map = {f'{layer_prefix}.word_embeddings': 0, f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0, f'base_model.model.lm_head': 0, } - used = 2 gpu_target = 0 for i in range(num_trans_layers): @@ -268,7 +279,7 @@ class LoaderCheckPoint: gpu_target += 1 used = 0 assert gpu_target < num_gpus - device_map[f'{layer_prefix}.layers.{i}'] = gpu_target + device_map[f'{layer_prefix}{encode}.layers.{i}'] = gpu_target used += 1 return device_map