Update loader.py 解决加载ptuning报错的问题
This commit is contained in:
parent
ec2bf9757c
commit
c1c2ed1943
|
|
@ -441,7 +441,7 @@ class LoaderCheckPoint:
|
||||||
|
|
||||||
if self.use_ptuning_v2:
|
if self.use_ptuning_v2:
|
||||||
try:
|
try:
|
||||||
prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
|
prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
|
||||||
prefix_encoder_config = json.loads(prefix_encoder_file.read())
|
prefix_encoder_config = json.loads(prefix_encoder_file.read())
|
||||||
prefix_encoder_file.close()
|
prefix_encoder_file.close()
|
||||||
self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
|
self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
|
||||||
|
|
@ -457,13 +457,14 @@ class LoaderCheckPoint:
|
||||||
|
|
||||||
if self.use_ptuning_v2:
|
if self.use_ptuning_v2:
|
||||||
try:
|
try:
|
||||||
prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
|
prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
|
||||||
new_prefix_state_dict = {}
|
new_prefix_state_dict = {}
|
||||||
for k, v in prefix_state_dict.items():
|
for k, v in prefix_state_dict.items():
|
||||||
if k.startswith("transformer.prefix_encoder."):
|
if k.startswith("transformer.prefix_encoder."):
|
||||||
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
||||||
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
||||||
self.model.transformer.prefix_encoder.float()
|
self.model.transformer.prefix_encoder.float()
|
||||||
|
print("加载ptuning检查点成功!")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
print("加载PrefixEncoder模型参数失败")
|
print("加载PrefixEncoder模型参数失败")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue