diff --git a/models/loader/loader.py b/models/loader/loader.py
index 986f5c4..f43bb1d 100644
--- a/models/loader/loader.py
+++ b/models/loader/loader.py
@@ -441,7 +441,7 @@ class LoaderCheckPoint:
 
         if self.use_ptuning_v2:
             try:
-                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
+                prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
                 prefix_encoder_config = json.loads(prefix_encoder_file.read())
                 prefix_encoder_file.close()
                 self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
@@ -457,13 +457,14 @@ class LoaderCheckPoint:
 
         if self.use_ptuning_v2:
             try:
-                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
+                prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
                 new_prefix_state_dict = {}
                 for k, v in prefix_state_dict.items():
                     if k.startswith("transformer.prefix_encoder."):
                         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
+                print("加载ptuning检查点成功!")
             except Exception as e:
                 print(e)
                 print("加载PrefixEncoder模型参数失败")
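
The two hunks resolve the P-Tuning v2 checkpoint directory with os.path.abspath, so the files are found regardless of the process's working directory, and report success once the prefix-encoder weights are applied. Below is a minimal standalone sketch of the same loading path, not the project's code verbatim: ptuning_dir and the commented-out model lines are placeholders, and it assumes a ChatGLM-style checkpoint whose prefix-encoder weights are stored under transformer.prefix_encoder.*.

import os
from pathlib import Path

import torch

# Placeholder; LoaderCheckPoint takes this value from its configuration (self.ptuning_dir).
ptuning_dir = "ptuning-v2"

# os.path.abspath anchors the relative directory to the current working
# directory, so the checkpoint resolves even when the loader is launched
# from outside the project root.
checkpoint = Path(os.path.abspath(ptuning_dir)) / "pytorch_model.bin"
prefix_state_dict = torch.load(checkpoint, map_location="cpu")

# Keep only the prefix-encoder tensors and strip the module prefix so the
# keys match the prefix encoder's own state_dict().
prefix = "transformer.prefix_encoder."
new_prefix_state_dict = {
    k[len(prefix):]: v
    for k, v in prefix_state_dict.items()
    if k.startswith(prefix)
}

# In loader.py these weights are then applied to the live model:
# self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
# self.model.transformer.prefix_encoder.float()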