90 lines
2.9 KiB
Python
90 lines
2.9 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
# Copyright (c) Megvii, Inc. and its affiliates.
|
|
|
|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from yolox.exp import Exp as MyExp
|
|
|
|
|
|
class Exp(MyExp):
    """Experiment description pairing a YOLOFPN backbone with a YOLOXHead.

    Attributes read here but never assigned in this class (``num_classes``,
    ``input_size``, ``train_ann``, ``seed``, ``data_num_workers`` and the
    ``degrees``/``translate``/``scale``/``shear``/``perspective`` settings)
    are presumably provided by the ``MyExp`` base class -- confirm against
    ``yolox.exp``.
    """

    def __init__(self):
        super().__init__()
        self.depth = 1.0
        self.width = 1.0
        # The experiment is named after this file: its basename up to the
        # first dot.
        _, this_file = os.path.split(os.path.realpath(__file__))
        self.exp_name = this_file.split(".")[0]

    def get_model(self, sublinear=False):
        """Return ``self.model``, constructing it first if it is not set yet.

        Args:
            sublinear: Accepted for interface compatibility; this method
                does not use it.

        Returns:
            The ``YOLOX`` instance stored on ``self.model``.
        """

        def tune_batchnorm(root):
            # Overwrite eps/momentum on every BatchNorm2d found under ``root``.
            for layer in root.modules():
                if not isinstance(layer, nn.BatchNorm2d):
                    continue
                layer.eps = 1e-3
                layer.momentum = 0.03

        if "model" not in vars(self):
            from yolox.models import YOLOX, YOLOFPN, YOLOXHead

            fpn = YOLOFPN()
            det_head = YOLOXHead(
                self.num_classes,
                self.width,
                in_channels=[128, 256, 512],
                act="lrelu",
            )
            self.model = YOLOX(fpn, det_head)

        # NOTE(review): the two calls below run on every invocation, not only
        # when the model was just built -- confirm this placement matches the
        # intended indentation of the original file.
        self.model.apply(tune_batchnorm)
        self.model.head.initialize_biases(1e-2)

        return self.model

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        """Build the training data loader.

        A ``COCODataset`` is wrapped in ``MosaicDetection``, stored on
        ``self.dataset``, and served through a ``YoloBatchSampler``.

        Args:
            batch_size: Requested batch size. When ``is_distributed`` is
                true it is integer-divided by ``dist.get_world_size()``.
            is_distributed: Selects ``InfiniteSampler`` (true) or
                ``torch.utils.data.RandomSampler`` (false).
            no_aug: When true, ``mosaic=False`` is passed to both
                ``MosaicDetection`` and ``YoloBatchSampler``.

        Returns:
            The ``DataLoader`` built over ``self.dataset``.
        """
        from data.datasets.cocodataset import COCODataset
        from data.datasets.mosaicdetection import MosaicDetection
        from data.datasets.data_augment import TrainTransform
        from data.datasets.dataloading import YoloBatchSampler, DataLoader, InfiniteSampler
        import torch.distributed as dist

        def make_transform(label_cap):
            # Both pipelines share the same mean/std and differ only in
            # ``max_labels``.
            return TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=label_cap,
            )

        use_mosaic = not no_aug

        coco = COCODataset(
            data_dir='data/COCO/',
            json_file=self.train_ann,
            img_size=self.input_size,
            preproc=make_transform(50),
        )
        self.dataset = MosaicDetection(
            coco,
            mosaic=use_mosaic,
            img_size=self.input_size,
            preproc=make_transform(120),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
        )

        if not is_distributed:
            sampler = torch.utils.data.RandomSampler(self.dataset)
        else:
            # The requested batch size is split across the process group.
            batch_size = batch_size // dist.get_world_size()
            sampler = InfiniteSampler(len(self.dataset), seed=self.seed or 0)

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=use_mosaic,
        )

        return DataLoader(
            self.dataset,
            num_workers=self.data_num_workers,
            pin_memory=True,
            batch_sampler=batch_sampler,
        )
|