#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import datetime
import os
import time

from loguru import logger

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from yolox.data import DataPrefetcher
from yolox.utils import (
    MeterBuffer,
    ModelEMA,
    all_reduce_norm,
    get_model_info,
    get_rank,
    get_world_size,
    gpu_mem_usage,
    load_ckpt,
    occupy_mem,
    save_checkpoint,
    setup_logger,
    synchronize,
)


class Trainer:
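    """Epoch/iteration training loop for YOLOX-style experiments.

    ``train()`` drives the whole run: ``before_train`` builds the model,
    optimizer, dataloader and the optional EMA/DDP wrappers; ``train_in_epoch``
    walks epochs and iterations through the before/after hooks; ``after_train``
    reports the best AP seen during evaluation.
    """
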
    def __init__(self, exp, args):
        # __init__ only sets basic attributes; heavier objects such as the model,
        # optimizer and dataloader are built in the before_train method.
        self.exp = exp
        self.args = args

        # training related attr
        self.max_epoch = exp.max_epoch
        self.amp_training = args.fp16
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
        self.is_distributed = get_world_size() > 1
        self.rank = get_rank()
        self.local_rank = args.local_rank
        self.device = "cuda:{}".format(self.local_rank)
        self.use_model_ema = exp.ema

        # data/dataloader related attr
        self.data_type = torch.float16 if args.fp16 else torch.float32
        self.input_size = exp.input_size
        self.best_ap = 0

        # metric record
        self.meter = MeterBuffer(window_size=exp.print_interval)
        self.file_name = os.path.join(exp.output_dir, args.experiment_name)

        if self.rank == 0:
            os.makedirs(self.file_name, exist_ok=True)

        setup_logger(
            self.file_name,
            distributed_rank=self.rank,
            filename="train_log.txt",
            mode="a",
        )

    def train(self):
        self.before_train()
        try:
            self.train_in_epoch()
        except Exception:
            raise
        finally:
            self.after_train()

    def train_in_epoch(self):
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.train_in_iter()
            self.after_epoch()

    def train_in_iter(self):
        for self.iter in range(self.max_iter):
            self.before_iter()
            self.train_one_iter()
            self.after_iter()

    def train_one_iter(self):
        iter_start_time = time.time()

        inps, targets = self.prefetcher.next()
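        # MOT targets carry a track id in the sixth column; the detector is
        # trained only on the first five columns (class label and box), so the
        # ids are split off and not used further in this loop.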
        track_ids = targets[:, :, 5]
        targets = targets[:, :, :5]
        inps = inps.to(self.data_type)
        targets = targets.to(self.data_type)
        targets.requires_grad = False
        data_end_time = time.time()

        with torch.cuda.amp.autocast(enabled=self.amp_training):
            outputs = self.model(inps, targets)
            loss = outputs["total_loss"]

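        # Standard GradScaler sequence for mixed precision: scale the loss before
        # backward, step through the scaler (which unscales and skips the update
        # on inf/NaN gradients), then adjust the scale factor. With fp16 disabled
        # the scaler is a pass-through.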
        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        if self.use_model_ema:
            self.ema_model.update(self.model)

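        # Query the scheduler with the 1-based global iteration index and push
        # the returned learning rate into every optimizer param group.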
        lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

        iter_end_time = time.time()
        self.meter.update(
            iter_time=iter_end_time - iter_start_time,
            data_time=data_end_time - iter_start_time,
            lr=lr,
            **outputs,
        )

    def before_train(self):
        logger.info("args: {}".format(self.args))
        logger.info("exp value:\n{}".format(self.exp))

        # model related init
        torch.cuda.set_device(self.local_rank)
        model = self.exp.get_model()
        logger.info(
            "Model Summary: {}".format(get_model_info(model, self.exp.test_size))
        )
        model.to(self.device)

        # solver related init
        self.optimizer = self.exp.get_optimizer(self.args.batch_size)

        # value of start_epoch will be set in `resume_train`
        model = self.resume_train(model)

        # data related init
        self.no_aug = self.start_epoch >= self.max_epoch - self.exp.no_aug_epochs
        self.train_loader = self.exp.get_data_loader(
            batch_size=self.args.batch_size,
            is_distributed=self.is_distributed,
            no_aug=self.no_aug,
        )
        logger.info("init prefetcher, this might take one minute or less...")
        self.prefetcher = DataPrefetcher(self.train_loader)
        # max_iter means iters per epoch
        self.max_iter = len(self.train_loader)

        self.lr_scheduler = self.exp.get_lr_scheduler(
            self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter
        )
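        # --occupy asks occupy_mem to pre-claim GPU memory up front, so later
        # allocations do not have to compete with other jobs on the same card.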
        if self.args.occupy:
            occupy_mem(self.local_rank)

        if self.is_distributed:
            model = DDP(model, device_ids=[self.local_rank], broadcast_buffers=False)

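        # The EMA model shadows the raw weights with decay 0.9998; `updates` is
        # seeded with the iterations already run so a resumed job continues the
        # EMA bookkeeping where it left off.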
        if self.use_model_ema:
            self.ema_model = ModelEMA(model, 0.9998)
            self.ema_model.updates = self.max_iter * self.start_epoch

        self.model = model
        self.model.train()

        self.evaluator = self.exp.get_evaluator(
            batch_size=self.args.batch_size, is_distributed=self.is_distributed
        )
        # Tensorboard logger
        if self.rank == 0:
            self.tblogger = SummaryWriter(self.file_name)

        logger.info("Training start...")
        # logger.info("\n{}".format(model))

    def after_train(self):
        logger.info(
            "Training of experiment is done and the best AP is {:.2f}".format(
                self.best_ap * 100
            )
        )

    def before_epoch(self):
        logger.info("---> start train epoch{}".format(self.epoch + 1))

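        # During the last `no_aug_epochs` (or when resuming inside that window),
        # mosaic augmentation is turned off, the extra L1 loss is enabled on the
        # head, and evaluation switches to every epoch.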
        if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
            logger.info("--->No mosaic aug now!")
            self.train_loader.close_mosaic()
            logger.info("--->Add additional L1 loss now!")
            if self.is_distributed:
                self.model.module.head.use_l1 = True
            else:
                self.model.head.use_l1 = True

            self.exp.eval_interval = 1
            if not self.no_aug:
                self.save_ckpt(ckpt_name="last_mosaic_epoch")

    def after_epoch(self):
        if self.use_model_ema:
            self.ema_model.update_attr(self.model)

        self.save_ckpt(ckpt_name="latest")

        if (self.epoch + 1) % self.exp.eval_interval == 0:
            all_reduce_norm(self.model)
            self.evaluate_and_save_model()

    def before_iter(self):
        pass

    def after_iter(self):
        """
        `after_iter` contains two parts of logic:

        * log training information
        * randomly resize the input for multi-scale training
        """
        # log needed information
        if (self.iter + 1) % self.exp.print_interval == 0:
            # TODO check ETA logic
            left_iters = self.max_iter * self.max_epoch - (self.progress_in_iter + 1)
            eta_seconds = self.meter["iter_time"].global_avg * left_iters
            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))

            progress_str = "epoch: {}/{}, iter: {}/{}".format(
                self.epoch + 1, self.max_epoch, self.iter + 1, self.max_iter
            )
            loss_meter = self.meter.get_filtered_meter("loss")
            loss_str = ", ".join(
                ["{}: {:.3f}".format(k, v.latest) for k, v in loss_meter.items()]
            )

            time_meter = self.meter.get_filtered_meter("time")
            time_str = ", ".join(
                ["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
            )

            logger.info(
                "{}, mem: {:.0f}Mb, {}, {}, lr: {:.3e}".format(
                    progress_str,
                    gpu_mem_usage(),
                    time_str,
                    loss_str,
                    self.meter["lr"].latest,
                )
                + ", size: {:d}, {}".format(self.input_size[0], eta_str)
            )
            self.meter.clear_meters()

        # random resizing: every 10 iterations a new input size is drawn
        # (kept in sync across ranks)
        if self.exp.random_size is not None and (self.progress_in_iter + 1) % 10 == 0:
            self.input_size = self.exp.random_resize(
                self.train_loader, self.epoch, self.rank, self.is_distributed
            )

    @property
    def progress_in_iter(self):
        return self.epoch * self.max_iter + self.iter

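    # resume_train either restores the full training state (model, optimizer,
    # start epoch) when --resume is set, or loads model weights only for
    # fine-tuning when just --ckpt is given.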
    def resume_train(self, model):
        if self.args.resume:
            logger.info("resume training")
            if self.args.ckpt is None:
                ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth.tar")
            else:
                ckpt_file = self.args.ckpt

            ckpt = torch.load(ckpt_file, map_location=self.device)
            # resume the model/optimizer state dict
            model.load_state_dict(ckpt["model"])
            self.optimizer.load_state_dict(ckpt["optimizer"])
            start_epoch = (
                self.args.start_epoch - 1
                if self.args.start_epoch is not None
                else ckpt["start_epoch"]
            )
            self.start_epoch = start_epoch
            logger.info(
                "loaded checkpoint '{}' (epoch {})".format(
                    self.args.resume, self.start_epoch
                )
            )  # noqa
        else:
            if self.args.ckpt is not None:
                logger.info("loading checkpoint for fine tuning")
                ckpt_file = self.args.ckpt
                ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
                model = load_ckpt(model, ckpt)
            self.start_epoch = 0

        return model

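    # Evaluation uses the EMA weights when EMA is enabled; AP@0.5 drives the
    # best-checkpoint selection below (the AP@0.5:0.95 variant is left
    # commented out).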
    def evaluate_and_save_model(self):
        evalmodel = self.ema_model.ema if self.use_model_ema else self.model
        ap50_95, ap50, summary = self.exp.eval(
            evalmodel, self.evaluator, self.is_distributed
        )
        self.model.train()
        if self.rank == 0:
            self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
            self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
            logger.info("\n" + summary)
        synchronize()

        # self.best_ap = max(self.best_ap, ap50_95)
        self.save_ckpt("last_epoch", ap50 > self.best_ap)
        self.best_ap = max(self.best_ap, ap50)

    def save_ckpt(self, ckpt_name, update_best_ckpt=False):
        if self.rank == 0:
            save_model = self.ema_model.ema if self.use_model_ema else self.model
            logger.info("Save weights to {}".format(self.file_name))
            ckpt_state = {
                "start_epoch": self.epoch + 1,
                "model": save_model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            }
            save_checkpoint(
                ckpt_state,
                update_best_ckpt,
                self.file_name,
                ckpt_name,
            )
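

# A minimal usage sketch, assuming the usual YOLOX/ByteTrack entry point where an
# Exp instance and parsed CLI args already exist (the actual launcher lives in
# tools/train.py and also handles distributed launch):
#
#     trainer = Trainer(exp, args)
#     trainer.train()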