#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

import math
from copy import deepcopy

import torch
import torch.nn as nn


def is_parallel(model):
    """Check if model is in parallel mode."""
    parallel_type = (
        nn.parallel.DataParallel,
        nn.parallel.DistributedDataParallel,
    )
    return isinstance(model, parallel_type)


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, optionally restricted to keys in `include`
    # and always skipping private keys and keys in `exclude`.
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        else:
            setattr(a, k, v)


class ModelEMA:
    """
    Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """
        Args:
            model (nn.Module): model to apply EMA.
            decay (float): ema decay rate.
            updates (int): counter of EMA updates.
        """
        # Create EMA(FP32)
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
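        # Note: the effective decay starts near 0 and approaches `decay`
        # asymptotically (about 0.63 * decay after 2000 updates), so the EMA
        # follows the raw weights closely early in training.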
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = (
                model.module.state_dict() if is_parallel(model) else model.state_dict()
            )  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    # In-place update of the EMA tensor: ema = d * ema + (1 - d) * model weight
                    v *= d
                    v += (1.0 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)
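

# Minimal usage sketch (illustration only): it demonstrates the ordering the class
# docstring warns about: build the model first, create the EMA copy, then call
# `ema.update(model)` after every optimizer step. The toy model, optimizer, and
# random data below are placeholders, not the real training pipeline.
if __name__ == "__main__":
    model = nn.Linear(4, 2)
    ema = ModelEMA(model, decay=0.9999)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for step in range(10):
        x = torch.randn(8, 4)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Keep the smoothed copy in sync after every weight update.
        ema.update(model)

    # Evaluate or checkpoint with the smoothed weights held in `ema.ema`.
    print(f"EMA updates: {ema.updates}, effective decay: {ema.decay(ema.updates):.6f}")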