107 lines
3.2 KiB
Python
107 lines
3.2 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from thop import profile
|
|
|
|
from copy import deepcopy
|
|
|
|
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["fuse_conv_and_bn", "fuse_model", "get_model_info", "replace_module"]
|
|
|
|
|
|
def get_model_info(model, tsize):
    """Summarize a model's parameter count and estimated GFLOPs.

    The model is profiled with ``thop.profile`` on one 64x64 all-zero
    3-channel image. The measured cost is then rescaled by the ratio of the
    target area to the 64x64 probe area. This assumes cost grows linearly with
    input area; TODO confirm for models with size-dependent layers.

    Args:
        model (nn.Module): model to profile. A deep copy is profiled, so any
            side effects of the profiler land on the copy, not on ``model``.
        tsize (Sequence[int]): two-element target input size. Only the
            product ``tsize[0] * tsize[1]`` is used.

    Returns:
        str: summary formatted as ``"Params: {:.2f}M, Gflops: {:.2f}"``.

    Raises:
        ValueError: if ``model`` has no parameters, so the device and dtype
            for the dummy input cannot be inferred.
    """
    try:
        first_param = next(model.parameters())
    except StopIteration:
        raise ValueError(
            "get_model_info requires a model with at least one parameter"
        ) from None

    stride = 64
    # Match the model's device AND dtype: a float32 probe would fail against
    # half-precision weights with a dtype-mismatch error.
    img = torch.zeros(
        (1, 3, stride, stride), device=first_param.device, dtype=first_param.dtype
    )
    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
    params /= 1e6
    flops /= 1e9
    # thop reports multiply-accumulates; x2 converts MACs to FLOPs, and the
    # area ratio rescales from the 64x64 probe to the requested size.
    flops *= tsize[0] * tsize[1] / stride / stride * 2  # Gflops
    info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
    return info
|
|
|
|
|
|
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm layer into the convolution that precedes it.

    Returns a new ``nn.Conv2d`` whose output equals ``bn(conv(x))`` when
    ``bn`` normalizes with its running statistics (eval mode). Neither input
    module is modified.

    Per output channel ``c``, with ``scale = gamma / sqrt(running_var + eps)``::

        W_fused[c] = scale[c] * W_conv[c]
        b_fused[c] = scale[c] * (b_conv[c] - running_mean[c]) + beta[c]

    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    Args:
        conv (nn.Conv2d): convolution to fuse. Its geometry (kernel size,
            stride, padding, dilation, groups, padding mode) is copied.
        bn (nn.BatchNorm2d): batch norm applied to ``conv``'s output. It must
            track running statistics. ``affine=False`` is supported
            (gamma = 1, beta = 0).

    Returns:
        nn.Conv2d: fused convolution with a bias, on ``conv``'s device and
        dtype, with gradients disabled.

    Raises:
        ValueError: if ``bn`` has no running statistics to fold in.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError("cannot fuse a BatchNorm layer without running statistics")

    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,  # omitting this breaks dilated convs
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(device=conv.weight.device, dtype=conv.weight.dtype)
    )

    # no_grad: the in-place copies must not record autograd history. Otherwise
    # the fused parameters become graph nodes that reference the original
    # conv/bn parameters, despite requires_grad_(False) above.
    with torch.no_grad():
        # per-output-channel multiplier gamma / sqrt(var + eps)
        std = torch.sqrt(bn.running_var + bn.eps)
        scale = 1.0 / std if bn.weight is None else bn.weight / std
        shift = torch.zeros_like(bn.running_mean) if bn.bias is None else bn.bias

        # prepare filters: scale every output channel's kernel (row) by scale[c]
        w_conv = conv.weight.reshape(conv.out_channels, -1)
        fusedconv.weight.copy_(
            (w_conv * scale.reshape(-1, 1)).view(fusedconv.weight.shape)
        )

        # prepare spatial bias
        b_conv = (
            torch.zeros(
                conv.out_channels,
                device=conv.weight.device,
                dtype=conv.weight.dtype,
            )
            if conv.bias is None
            else conv.bias
        )
        fusedconv.bias.copy_((b_conv - bn.running_mean) * scale + shift)

    return fusedconv
|
|
|
|
|
|
def fuse_model(model):
    """Fold the BatchNorm of every YOLOX ``BaseConv`` block into its conv.

    Each matching block gets the following changes:
    - its ``conv`` is replaced by ``fuse_conv_and_bn(conv, bn)``;
    - its ``bn`` attribute is deleted;
    - its ``forward`` is rebound to its ``fuseforward`` method.

    The model is modified in place and also returned.

    Args:
        model (nn.Module): model that may contain ``BaseConv`` blocks.

    Returns:
        nn.Module: the same ``model`` object.
    """
    from yolox.models.network_blocks import BaseConv

    # Exact BaseConv instances only (subclasses are skipped) that still own a
    # ``bn``. A block already fused has had ``bn`` deleted, so it is skipped.
    fusable = [
        block
        for block in model.modules()
        if type(block) is BaseConv and hasattr(block, "bn")
    ]
    for block in fusable:
        block.conv = fuse_conv_and_bn(block.conv, block.bn)  # update conv
        delattr(block, "bn")  # remove batchnorm
        block.forward = block.fuseforward  # update forward
    return model
|
|
|
|
|
|
def replace_module(module, replaced_module_type, new_module_type, replace_func=None):
    """
    Replace every module of a given type with a new module. Mostly used in deploy.

    Every submodule that is an instance of ``replaced_module_type``, including
    ``module`` itself, is swapped for the object returned by ``replace_func``.
    Children are replaced in place on their parent. A module that matches is
    not descended into.

    Args:
        module (nn.Module): model to apply the replace operation to.
        replaced_module_type (Type): module type to be replaced.
        new_module_type (Type): module type to substitute. It is forwarded to
            ``replace_func``.
        replace_func (function): called as
            ``replace_func(replaced_module_type, new_module_type)`` and must
            return the replacement module. It is applied at every nesting
            level. Default value None, meaning ``new_module_type()``.

    Returns:
        model (nn.Module): ``module`` with matching descendants replaced, or
        the result of ``replace_func`` if ``module`` itself matches.
    """

    def default_replace_func(replaced_module_type, new_module_type):
        return new_module_type()

    if replace_func is None:
        replace_func = default_replace_func

    if isinstance(module, replaced_module_type):
        return replace_func(replaced_module_type, new_module_type)

    # Recursively replace. replace_func must be propagated, otherwise nested
    # matches would silently fall back to the default replacement.
    for name, child in module.named_children():
        new_child = replace_module(
            child, replaced_module_type, new_module_type, replace_func
        )
        if new_child is not child:  # child was replaced
            module.add_module(name, new_child)

    return module
|