45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
|
|
from loguru import logger
|
|
|
|
import torch
|
|
|
|
import os
|
|
import shutil
|
|
|
|
|
|
def load_ckpt(model, ckpt):
    """Copy the compatible entries of ``ckpt`` into ``model`` and return it.

    Walks the keys of ``model.state_dict()``. An entry is taken from ``ckpt``
    only when the key is present there and its ``shape`` equals the shape of
    the model's entry; otherwise a warning is logged and the entry is left
    out. Keys that exist only in ``ckpt`` are never looked at.

    Args:
        model: object exposing ``state_dict()`` and ``load_state_dict()``.
        ckpt: mapping from state-dict key to a value with a ``shape``
            attribute.

    Returns:
        The same ``model`` object, after
        ``load_state_dict(..., strict=False)`` was called on it.
    """
    compatible = {}
    for name, current in model.state_dict().items():
        if name in ckpt:
            candidate = ckpt[name]
            if current.shape == candidate.shape:
                compatible[name] = candidate
            else:
                logger.warning(
                    "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
                        name, candidate.shape, name, current.shape
                    )
                )
        else:
            logger.warning(
                "{} is not in the ckpt. Please double check and see if this is desired.".format(
                    name
                )
            )

    # strict=False because the entries skipped above are absent from the dict.
    model.load_state_dict(compatible, strict=False)
    return model
|
|
|
|
|
|
def save_checkpoint(state, is_best, save_dir, model_name=""):
    """Write ``state`` to ``save_dir`` and optionally duplicate it as the best.

    The checkpoint is written with ``torch.save`` to
    ``<save_dir>/<model_name>_ckpt.pth.tar``. When ``is_best`` is truthy the
    written file is additionally copied to ``<save_dir>/best_ckpt.pth.tar``,
    overwriting any previous copy.

    Args:
        state: object handed to ``torch.save`` unchanged.
        is_best: when truthy, also copy the checkpoint to the "best" file.
        save_dir: target directory; created (with parents) if missing.
        model_name: prefix of the checkpoint file name.

    Raises:
        OSError: if the directory cannot be created or a file cannot be
            written.
    """
    # exist_ok=True replaces an exists()-then-makedirs() check: with that
    # check, another process creating the directory in between would make
    # makedirs() raise FileExistsError.
    os.makedirs(save_dir, exist_ok=True)

    filename = os.path.join(save_dir, model_name + "_ckpt.pth.tar")
    torch.save(state, filename)

    if is_best:
        best_filename = os.path.join(save_dir, "best_ckpt.pth.tar")
        shutil.copyfile(filename, best_filename)
|