From c1c2ed19435a6dc827da4acc1d8a701260f0caae Mon Sep 17 00:00:00 2001
From: chinainfant <38291328+chinainfant@users.noreply.github.com>
Date: Fri, 21 Jul 2023 15:30:19 +0800
Subject: [PATCH] =?UTF-8?q?Update=20loader.py=20=E8=A7=A3=E5=86=B3?=
 =?UTF-8?q?=E5=8A=A0=E8=BD=BDptuning=E6=8A=A5=E9=94=99=E7=9A=84=E9=97=AE?=
 =?UTF-8?q?=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 models/loader/loader.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/models/loader/loader.py b/models/loader/loader.py
index 986f5c4..f43bb1d 100644
--- a/models/loader/loader.py
+++ b/models/loader/loader.py
@@ -441,7 +441,7 @@ class LoaderCheckPoint:
 
         if self.use_ptuning_v2:
             try:
-                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
+                prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
                 prefix_encoder_config = json.loads(prefix_encoder_file.read())
                 prefix_encoder_file.close()
                 self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
@@ -457,13 +457,14 @@ class LoaderCheckPoint:
 
         if self.use_ptuning_v2:
             try:
-                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
+                prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
                 new_prefix_state_dict = {}
                 for k, v in prefix_state_dict.items():
                     if k.startswith("transformer.prefix_encoder."):
                         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
+                print("加载ptuning检查点成功!")
             except Exception as e:
                 print(e)
                 print("加载PrefixEncoder模型参数失败")