#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import bisect
from functools import wraps

from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
from torch.utils.data.dataset import Dataset as torchDataset


class ConcatDataset(torchConcatDataset):
    """Concatenation of datasets that also forwards ``pull_item`` calls and
    inherits the ``input_dim`` of its first dataset."""

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__(datasets)
        if hasattr(self.datasets[0], "input_dim"):
            self._input_dim = self.datasets[0].input_dim
            self.input_dim = self.datasets[0].input_dim

    def pull_item(self, idx):
        # Support negative indices, mirroring Python list semantics.
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        # Find the sub-dataset holding the global index, then convert the
        # global index into one local to that sub-dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].pull_item(sample_idx)
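
# Usage sketch (hypothetical datasets; anything implementing ``pull_item``
# and exposing ``input_dim`` works):
#
#     train_set = ConcatDataset([coco_train, voc_train])
#     sample = train_set.pull_item(0)     # served by the first sub-dataset
#     sample = train_set.pull_item(-1)    # negative indices are supported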


class MixConcatDataset(torchConcatDataset):
    """Concatenation of datasets whose ``__getitem__`` accepts the
    ``(input_dim, index, enable_mosaic)`` tuples emitted by the YOLOX
    data loader, remapping the index onto the right sub-dataset."""

    def __init__(self, datasets):
        super(MixConcatDataset, self).__init__(datasets)
        if hasattr(self.datasets[0], "input_dim"):
            self._input_dim = self.datasets[0].input_dim
            self.input_dim = self.datasets[0].input_dim

    def __getitem__(self, index):
        # A tuple index packs (input_dim, sample index, enable_mosaic);
        # the integer sample index sits in the middle slot.
        if not isinstance(index, int):
            idx = index[1]
        else:
            idx = index
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        # Rebuild the index with the sample slot remapped to the sub-dataset.
        if not isinstance(index, int):
            index = (index[0], sample_idx, index[2])
        else:
            index = sample_idx
        return self.datasets[dataset_idx][index]
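
# Index-tuple sketch (hypothetical values): the target resolution and mosaic
# flag travel with the sample index, and the sub-datasets are expected to
# accept the same tuple format (e.g. via ``Dataset.resize_getitem`` below):
#
#     mixed = MixConcatDataset([ds_a, ds_b])
#     item = mixed[(640, 640), 3, True]   # (input_dim, index, enable_mosaic)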


class Dataset(torchDataset):
    """ This class is a subclass of the base :class:`torch.utils.data.Dataset`
    that enables on-the-fly resizing of the ``input_dim``.

    Args:
        input_dimension (tuple): (width, height) tuple with the default
            dimensions of the network
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """
        Dimension that can be used by transforms to set the correct image size, etc.
        This allows transforms to have a single source of truth
        for the input dimension of the network.

        Returns:
            tuple: (width, height) currently in use
        """
        # A temporary ``_input_dim`` (installed by ``resize_getitem``) takes
        # precedence over the default dimension.
        if hasattr(self, "_input_dim"):
            return self._input_dim
        return self.__input_dim
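
    # Usage sketch (hypothetical subclass): for ``d = CustomSet((416, 416))``,
    # ``d.input_dim`` yields (416, 416) until ``resize_getitem`` temporarily
    # installs ``d._input_dim`` during a tuple-indexed ``__getitem__`` call.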

    @staticmethod
    def resize_getitem(getitem_fn):
        """
        Decorator method that needs to be used around the ``__getitem__`` method. |br|
        This decorator enables on-the-fly resizing of the ``input_dim``
        with our :class:`~lightnet.data.DataLoader` class.

        Example:
            >>> class CustomSet(ln.data.Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @ln.data.Dataset.resize_getitem
            ...     def __getitem__(self, index):
            ...         # Should return (image, anno) but here we return input_dim
            ...         return self.input_dim
            >>> data = CustomSet((200, 200))
            >>> data[0]
            (200, 200)
            >>> data[(480, 320), 0, True]
            (480, 320)
        """

        @wraps(getitem_fn)
        def wrapper(self, index):
            if not isinstance(index, int):
                # Tuple index: unpack (input_dim, index, enable_mosaic) and
                # install the temporary dimension override.
                has_dim = True
                self._input_dim = index[0]
                self.enable_mosaic = index[2]
                index = index[1]
            else:
                has_dim = False

            ret_val = getitem_fn(self, index)

            # Drop the override so later int-indexed calls fall back to the
            # default ``input_dim``.
            if has_dim:
                del self._input_dim

            return ret_val

        return wrapper
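

# End-to-end sketch (``CustomSet`` as in the docstring above; the tuple index
# is the format a multiscale-aware batch sampler would emit, which in YOLOX
# is paired with ``YoloBatchSampler``):
#
#     dataset = CustomSet((640, 640))
#     dataset[3]                      # uses the default 640x640 input_dim
#     dataset[(608, 608), 3, True]    # temporarily resized to 608x608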