2 changes: 1 addition & 1 deletion timm/data/__init__.py
@@ -4,7 +4,7 @@
from .transforms import *
from .loader import create_loader
from .transforms_factory import create_transform
from .mixup import mixup_batch, FastCollateMixup
from .mixup import Mixup, FastCollateMixup
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .real_labels import RealLabelsImagenet
2 changes: 1 addition & 1 deletion timm/data/dataset.py
@@ -92,7 +92,7 @@ def __getitem__(self, index):
return img, target

def __len__(self):
return len(self.imgs)
return len(self.samples)

def filenames(self, indices=[], basename=False):
if indices:
247 changes: 222 additions & 25 deletions timm/data/mixup.py
@@ -1,5 +1,12 @@
""" Mixup
Paper: `mixup: Beyond Empirical Risk Minimization` - https://arxiv.org/abs/1710.09412
""" Mixup and Cutmix

Papers:
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)

CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)

Code Reference:
CutMix: https://github.com/clovaai/CutMix-PyTorch

Hacked together by / Copyright 2020 Ross Wightman
"""
@@ -17,40 +24,230 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
on_value = 1. - smoothing + off_value
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
return lam*y1 + (1. - lam)*y2
return y1 * lam + y2 * (1. - lam)
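For a concrete feel of the target mixing above, here is a small worked example (editor's illustration, not part of the diff): with 3 classes, smoothing=0.1 and lam=0.6, a target batch [0, 2] is mixed with its flipped counterpart.

# off_value = 0.1 / 3 ≈ 0.033, on_value = 1 - 0.1 + 0.033 ≈ 0.933
# row 0: 0.6 * [0.933, 0.033, 0.033] + 0.4 * [0.033, 0.033, 0.933] ≈ [0.573, 0.033, 0.393]
# row 1 is the mirror image: ≈ [0.393, 0.033, 0.573]
mixed = mixup_target(torch.tensor([0, 2]), num_classes=3, lam=0.6, smoothing=0.1, device='cpu')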


def rand_bbox(img_shape, lam, margin=0., count=None):
""" Standard CutMix bounding-box
Generates a random square bbox based on lambda value. This impl includes
support for enforcing a border margin as percent of bbox dimensions.

Args:
img_shape (tuple): Image shape as tuple
lam (float): Cutmix lambda value
margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
count (int): Number of bbox to generate
"""
ratio = np.sqrt(1 - lam)
img_h, img_w = img_shape[-2:]
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
yl = np.clip(cy - cut_h // 2, 0, img_h)
yh = np.clip(cy + cut_h // 2, 0, img_h)
xl = np.clip(cx - cut_w // 2, 0, img_w)
xh = np.clip(cx + cut_w // 2, 0, img_w)
return yl, yh, xl, xh
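A hedged usage sketch of rand_bbox (illustration only, not in the diff): since the cut ratio is sqrt(1 - lam), lam = 0.75 yields a box covering roughly a quarter of the image area.

yl, yh, xl, xh = rand_bbox((3, 224, 224), lam=0.75)
# cut_h = cut_w = int(224 * sqrt(0.25)) = 112, so the box is about 112x112,
# smaller when the randomly placed centre forces clipping at an image border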


def rand_bbox_minmax(img_shape, minmax, count=None):
""" Min-Max CutMix bounding-box
Inspired by Darknet cutmix impl, generates a random rectangular bbox
based on min/max percent values applied to each dimension of the input image.

Typical defaults for minmax are in the .2-.3 range for min and .8-.9 range for max.

def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
lam = 1.
if not disable:
lam = np.random.beta(alpha, alpha)
input = input.mul(lam).add_(1 - lam, input.flip(0))
target = mixup_target(target, num_classes, lam, smoothing)
return input, target
Args:
img_shape (tuple): Image shape as tuple
minmax (tuple or list): Min and max bbox ratios (as percent of image size)
count (int): Number of bbox to generate
"""
assert len(minmax) == 2
img_h, img_w = img_shape[-2:]
cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
yl = np.random.randint(0, img_h - cut_h, size=count)
xl = np.random.randint(0, img_w - cut_w, size=count)
yu = yl + cut_h
xu = xl + cut_w
return yl, yu, xl, xu


class FastCollateMixup:
def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
""" Generate bbox and apply lambda correction.
"""
if ratio_minmax is not None:
yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
else:
yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
if correct_lam or ratio_minmax is not None:
bbox_area = (yu - yl) * (xu - xl)
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
return (yl, yu, xl, xu), lam
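A short sketch of the lambda correction (editor's illustration, not part of the diff): the returned lam is recomputed from the clipped box area rather than the requested ratio.

(yl, yu, xl, xu), lam = cutmix_bbox_and_lam((3, 224, 224), lam=0.75, correct_lam=True)
# if clipping leaves, say, a 112x80 box, lam becomes 1 - (112 * 80) / (224 * 224) ≈ 0.821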

def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):

class Mixup:
""" Mixup/Cutmix that applies different params to each element or whole batch

Args:
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
prob (float): probability of applying mixup or cutmix per batch or element
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
label_smoothing (float): apply label smoothing to the mixed target tensor
num_classes (int): number of classes for target
"""
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.cutmix_minmax = cutmix_minmax
if self.cutmix_minmax is not None:
assert len(self.cutmix_minmax) == 2
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
self.cutmix_alpha = 1.0
self.mix_prob = prob
self.switch_prob = switch_prob
self.label_smoothing = label_smoothing
self.num_classes = num_classes
self.mixup_enabled = True
self.elementwise = elementwise
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)

def __call__(self, batch):
batch_size = len(batch)
lam = 1.
def _params_per_elem(self, batch_size):
lam = np.ones(batch_size, dtype=np.float32)
use_cutmix = np.zeros(batch_size, dtype=np.bool)
if self.mixup_enabled:
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
use_cutmix = np.random.rand(batch_size) < self.switch_prob
lam_mix = np.where(
use_cutmix,
np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
elif self.mixup_alpha > 0.:
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
elif self.cutmix_alpha > 0.:
use_cutmix = np.ones(batch_size, dtype=np.bool)
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
return lam, use_cutmix

target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
def _params_per_batch(self):
lam = 1.
use_cutmix = False
if self.mixup_enabled and np.random.rand() < self.mix_prob:
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
use_cutmix = np.random.rand() < self.switch_prob
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
np.random.beta(self.mixup_alpha, self.mixup_alpha)
elif self.mixup_alpha > 0.:
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
elif self.cutmix_alpha > 0.:
use_cutmix = True
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = float(lam_mix)
return lam, use_cutmix
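As a rough guide to the alpha parameters sampled above (editor's illustration): Beta(1, 1) is uniform on [0, 1], while a small alpha such as 0.2 concentrates samples near 0 and 1, so most batches end up only lightly mixed.

lam_uniform = np.random.beta(1.0, 1.0)   # any mixing strength equally likely
lam_peaked = np.random.beta(0.2, 0.2)    # usually close to 0 or 1, i.e. mild mixing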

tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
def _mix_elem(self, x):
batch_size = len(x)
lam_batch, use_cutmix = self._params_per_elem(batch_size)
x_orig = x.clone() # need to keep an unmodified original for mixing source
for i in range(batch_size):
mixed = batch[i][0].astype(np.float32) * lam + \
batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
np.round(mixed, out=mixed)
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
j = batch_size - i - 1
lam = lam_batch[i]
if lam != 1.:
if use_cutmix[i]:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)

def _mix_batch(self, x):
lam, use_cutmix = self._params_per_batch()
if lam == 1.:
return 1.
if use_cutmix:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
else:
x_flipped = x.flip(0).mul_(1. - lam)
x.mul_(lam).add_(x_flipped)
return lam

def __call__(self, x, target):
assert len(x) % 2 == 0, 'Batch size should be even when using this'
lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
return x, target
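A hedged end-to-end sketch of how the Mixup class is intended to be used in a training step; `model` and `loader` are placeholders, and SoftTargetCrossEntropy is the timm loss that train.py pairs with it below.

from timm.loss import SoftTargetCrossEntropy

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
                 label_smoothing=0.1, num_classes=1000)
loss_fn = SoftTargetCrossEntropy()

for input, target in loader:                  # placeholder loader yielding CUDA tensors, even batch size
    input, target = mixup_fn(input, target)   # target becomes a soft (batch, num_classes) tensor
    loss = loss_fn(model(input), target)      # placeholder model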


class FastCollateMixup(Mixup):
""" Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch

A Mixup impl that's performed while collating the batches.
"""

def _mix_elem_collate(self, output, batch):
batch_size = len(batch)
lam_batch, use_cutmix = self._params_per_elem(batch_size)
for i in range(batch_size):
j = batch_size - i - 1
lam = lam_batch[i]
mixed = batch[i][0]
if lam != 1.:
if use_cutmix[i]:
mixed = mixed.copy()
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
lam_batch[i] = lam
np.round(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8))
return torch.tensor(lam_batch).unsqueeze(1)

def _mix_batch_collate(self, output, batch):
batch_size = len(batch)
lam, use_cutmix = self._params_per_batch()
if use_cutmix:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
for i in range(batch_size):
j = batch_size - i - 1
mixed = batch[i][0]
if lam != 1.:
if use_cutmix:
mixed = mixed.copy() # don't want to modify the original while iterating
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
else:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.round(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8))
return lam

def __call__(self, batch, _=None):
batch_size = len(batch)
assert batch_size % 2 == 0, 'Batch size should be even when using this'
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
if self.elementwise:
lam = self._mix_elem_collate(output, batch)
else:
lam = self._mix_batch_collate(output, batch)
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
return output, target

return tensor, target
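And a sketch of the collate-time variant above, which is what train.py wires up when the prefetcher is active; `dataset_train` is a placeholder and is assumed to yield CHW uint8 numpy arrays, as timm's prefetcher pipeline does.

collate_fn = FastCollateMixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=1000)
loader = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True, collate_fn=collate_fn)
# each yielded batch is a (uint8 image tensor, soft CPU target tensor) pair;
# the prefetcher then normalizes the images and moves both to the GPU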
55 changes: 35 additions & 20 deletions train.py
@@ -28,7 +28,7 @@
from torch.nn.parallel import DistributedDataParallel as DDP
has_apex = False

from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset
from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
@@ -156,7 +156,17 @@
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.0,
help='Mixup alpha, mixup enabled if > 0. (default: 0.)')
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=0.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-elem', action='store_true', default=False,
help='Apply mixup/cutmix params uniquely per batch element instead of per batch.')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
@@ -388,9 +398,18 @@ def main():
dataset_train = Dataset(train_dir)

collate_fn = None
if args.prefetcher and args.mixup > 0:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, elementwise=args.mixup_elem,
label_smoothing=args.smoothing, num_classes=args.num_classes)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)

if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
@@ -452,17 +471,14 @@ def main():
if args.jsd:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.mixup > 0.:
# smoothing is handled with mixup label transform
elif mixup_active:
# smoothing is handled with mixup target transform
train_loss_fn = SoftTargetCrossEntropy().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = train_loss_fn
validate_loss_fn = nn.CrossEntropyLoss().cuda()

eval_metric = args.eval_metric
best_metric = None
@@ -490,7 +506,7 @@ def main():
train_metrics = train_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
use_amp=use_amp, model_ema=model_ema)
use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn)

if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
@@ -530,11 +546,13 @@ def main():

def train_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None):

if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
elif mixup_fn is not None:
mixup_fn.mixup_enabled = False

batch_time_m = AverageMeter()
data_time_m = AverageMeter()
@@ -550,11 +568,8 @@ def train_epoch(
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
if args.mixup > 0.:
input, target = mixup_batch(
input, target,
alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)
if mixup_fn is not None:
input, target = mixup_fn(input, target)

output = model(input)
