#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

import pickle
from collections import OrderedDict

import torch
from torch import distributed as dist
from torch import nn

from .dist import _get_global_gloo_group, get_world_size

# Normalization layers whose states are kept local to each rank during training
# and averaged across ranks afterwards by all_reduce_norm().
ASYNC_NORM = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
)

__all__ = [
    "get_async_norm_states",
    "pyobj2tensor",
    "tensor2pyobj",
    "all_reduce",
    "all_reduce_norm",
]


def get_async_norm_states(module):
    """Collect the state_dict entries of every ASYNC_NORM layer in ``module``,
    keyed by their fully qualified names (e.g. ``"<prefix>.running_mean"``)."""
    async_norm_states = OrderedDict()
    for name, child in module.named_modules():
        if isinstance(child, ASYNC_NORM):
            for k, v in child.state_dict().items():
                async_norm_states[".".join([name, k])] = v
    return async_norm_states


def pyobj2tensor(pyobj, device="cuda"):
    """Serialize a picklable Python object to a ByteTensor on ``device``."""
    storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
    return torch.ByteTensor(storage).to(device=device)


def tensor2pyobj(tensor):
    """Deserialize a ByteTensor (as produced by pyobj2tensor) back to a Python object."""
    return pickle.loads(tensor.cpu().numpy().tobytes())


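# Round-trip sketch for the two helpers above (pyobj2tensor / tensor2pyobj),
# assuming a CUDA device is available; all_reduce() below relies on this pair
# to broadcast the dict keys between ranks:
#
#     keys = ["stem.bn.running_mean", "stem.bn.running_var"]  # illustrative names
#     assert tensor2pyobj(pyobj2tensor(keys)) == keys

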
def _get_reduce_op(op_name):
    # "mean" intentionally maps to SUM: all_reduce() divides the summed result
    # by the world size afterwards.
    return {
        "sum": dist.ReduceOp.SUM,
        "mean": dist.ReduceOp.SUM,
    }[op_name.lower()]


def all_reduce(py_dict, op="sum", group=None):
    """
    Apply an all-reduce operation to a dict of tensors.
    NOTE: every rank must pass a py_dict with the same keys, and corresponding
    values must have the same shape on every rank.

    Args:
        py_dict (dict): dict of tensors to all-reduce.
        op (str): reduce operator, either "sum" or "mean".
    """
    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return py_dict

    # all reduce logic across different devices.
    # Broadcast the key order from rank 0 so every rank packs values identically.
    py_key = list(py_dict.keys())
    py_key_tensor = pyobj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0)
    py_key = tensor2pyobj(py_key_tensor)

    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    # Flatten all values into one tensor so a single all_reduce call suffices.
    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
    if op == "mean":
        flatten_tensor /= world_size

    split_tensors = [
        x.reshape(shape)
        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})


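# Usage sketch for all_reduce(): inside a running distributed job (process group
# initialized, values on the default CUDA device), per-rank tensors with
# identical keys and shapes can be averaged like this; `loss` and `num_fg` are
# illustrative names, not part of this module:
#
#     stats = OrderedDict(loss=loss.detach(), num_fg=num_fg.float())
#     stats = all_reduce(stats, op="mean")

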
def all_reduce_norm(module):
    """
    Average the ASYNC_NORM states (running statistics and affine parameters)
    of ``module`` across all devices and load the result back in place.
    """
    states = get_async_norm_states(module)
    states = all_reduce(states, op="mean")
    module.load_state_dict(states, strict=False)
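

# Minimal single-process sanity check (a sketch: assumes this file sits in its
# usual package so the relative import of `.dist` resolves, e.g. run via
# `python -m <package>.allreduce_norm`). Outside a distributed job the world
# size is 1, so all_reduce() returns its input unchanged and all_reduce_norm()
# simply reloads the module's own statistics.
if __name__ == "__main__":
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    model(torch.randn(4, 3, 32, 32))  # forward pass populates BN running stats
    states = get_async_norm_states(model)
    print(sorted(states.keys()))  # ['1.bias', '1.num_batches_tracked', ...]
    all_reduce_norm(model)  # effectively a no-op when world size == 1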