#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

import torch
|
|
from torch.nn import Module
|
|
|
|
from yolox.utils import LRScheduler
|
|
|
|
import ast
|
|
import pprint
|
|
from abc import ABCMeta, abstractmethod
|
|
from tabulate import tabulate
|
|
from typing import Dict
|
|
|
|
|
|
class BaseExp(metaclass=ABCMeta):
    """Basic class for any experiment.

    Subclasses implement the abstract factory methods below (model, data
    loaders, optimizer, LR scheduler, evaluator, evaluation) and may add or
    override plain attributes in ``__init__``. ``merge`` lets a flat
    ``[key, value, key, value, ...]`` list update attributes that already
    exist on the instance; ``repr()`` renders the public attributes as a
    table.
    """

    def __init__(self):
        # ``merge`` only overrides keys for which ``hasattr(self, key)`` is
        # true, so a setting must have a default here (or in a subclass) to be
        # overridable. A ``None`` default accepts the override value unconverted.
        self.seed = None
        self.output_dir = "./YOLOX_outputs"
        self.print_interval = 100
        self.eval_interval = 10

    @abstractmethod
    def get_model(self) -> Module:
        """Build and return the ``torch.nn.Module`` for this experiment."""

    @abstractmethod
    def get_data_loader(
        self, batch_size: int, is_distributed: bool
    ) -> Dict[str, torch.utils.data.DataLoader]:
        """Return the experiment's data loaders, keyed by name.

        The meaning of the keys is defined by the implementing subclass
        (NOTE(review): not visible here — confirm against implementations).
        """

    @abstractmethod
    def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
        """Return the optimizer to use for the given ``batch_size``."""

    @abstractmethod
    def get_lr_scheduler(
        self, lr: float, iters_per_epoch: int, **kwargs
    ) -> LRScheduler:
        """Return the learning-rate scheduler for base rate ``lr``.

        ``iters_per_epoch`` and any extra ``kwargs`` are forwarded by the
        caller; their exact use is up to the implementing subclass.
        """

    @abstractmethod
    def get_evaluator(self):
        """Return the evaluator object later passed to ``eval``."""

    @abstractmethod
    def eval(self, model, evaluator, weights):
        """Evaluate ``model`` with ``evaluator`` (``weights`` as supplied by the caller)."""

    def __repr__(self):
        """Render all non-underscore instance attributes as a two-column table."""
        table_header = ["keys", "values"]
        exp_table = [
            (str(k), pprint.pformat(v))
            for k, v in vars(self).items()
            if not k.startswith("_")
        ]
        return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")

    def merge(self, cfg_list):
        """Override existing attributes from a flat key/value list.

        Args:
            cfg_list: Sequence of alternating keys and values, e.g.
                ``["max_epoch", "300", "seed", "42"]``. Keys that are not
                already attributes of the experiment are silently skipped.

        Each value whose type differs from the current attribute's type is
        converted to that type (unless the current value is ``None``, in
        which case the value is stored as-is). See ``_cast_value``.

        Raises:
            AssertionError: if ``cfg_list`` has an odd number of items.
            Exceptions from ``ast.literal_eval`` (e.g. ``ValueError``) if a
            string value cannot be converted to the attribute's type.
        """
        assert len(cfg_list) % 2 == 0, (
            "cfg_list must contain key/value pairs, got odd length "
            f"{len(cfg_list)}: {cfg_list}"
        )
        for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
            # only update value with same key
            if not hasattr(self, k):
                continue
            src_value = getattr(self, k)
            src_type = type(src_value)
            if src_value is not None and type(v) is not src_type:
                v = self._cast_value(v, src_type)
            setattr(self, k, v)

    @staticmethod
    def _cast_value(value, src_type):
        """Convert an override ``value`` to ``src_type``.

        For ``bool`` and sequence/set attributes, string values are parsed as
        Python literals, because their constructors silently mis-handle
        strings: ``bool("False")`` is ``True`` and ``list("[1, 2]")`` splits
        into characters. Bools also accept case-insensitive ``true``/``false``.
        A lone scalar for a container attribute (``"5"``) becomes a
        one-element container.

        Every other conversion first tries ``src_type(value)`` and falls back
        to ``ast.literal_eval(value)`` if the constructor raises.
        """
        if isinstance(value, str) and src_type in (bool, list, tuple, set, frozenset):
            text = value.strip()
            if src_type is bool:
                lowered = text.lower()
                if lowered in ("true", "false"):
                    return lowered == "true"
                return bool(ast.literal_eval(text))
            parsed = ast.literal_eval(text)
            if type(parsed) is src_type:
                return parsed
            if not isinstance(parsed, (list, tuple, set, frozenset)):
                parsed = (parsed,)
            # e.g. "1, 2" parses to a tuple while the attribute is a list.
            return src_type(parsed)
        try:
            return src_type(value)
        except Exception:
            # Deliberate fallback, not a swallow: literal_eval re-raises if the
            # value is not a valid literal either.
            return ast.literal_eval(value)