Adjusting EMA decay scheme.
datumbox committed Sep 28, 2021
1 parent 02b4d42 commit 33a90f7
Showing 1 changed file with 20 additions and 9 deletions.
references/classification/train.py (29 changes: 20 additions & 9 deletions)
@@ -20,7 +20,7 @@


 def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
-                    print_freq, apex=False, model_ema=None):
+                    print_freq, args, apex=False, model_ema=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
@@ -41,16 +41,19 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
             loss.backward()
         optimizer.step()

+        if model_ema and i % args.model_ema_steps == 0:
+            model_ema.update_parameters(model)
+            if epoch < args.lr_warmup_epochs:
+                # Reset ema buffer to keep copying weights during warmup period
+                model_ema.n_averaged.fill_(0)
+
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = image.shape[0]
         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
         metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
-
-        if model_ema and i % model_ema.update_steps == 0:
-            model_ema.update_parameters(model)


 def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
     model.eval()
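For context on the warmup reset above: torchvision's utils.ExponentialMovingAverage builds on torch.optim.swa_utils.AveragedModel, whose update_parameters() copies the raw weights whenever its n_averaged buffer is 0 and only starts averaging afterwards, so zeroing the buffer keeps the EMA weights equal to the live weights throughout LR warmup. Below is a minimal sketch of such a wrapper; the class name EMASketch and its avg_fn are illustrative, not the actual utils implementation.

from torch.optim.swa_utils import AveragedModel


class EMASketch(AveragedModel):
    """Illustrative EMA wrapper; the real class lives in references/classification/utils.py."""

    def __init__(self, model, decay, device='cpu'):
        def ema_avg(avg_param, param, num_averaged):
            # Standard EMA step: new_avg = decay * avg + (1 - decay) * current
            return decay * avg_param + (1.0 - decay) * param

        super().__init__(model, device=device, avg_fn=ema_avg)

# While n_averaged == 0, update_parameters(model) copies the model's weights verbatim;
# resetting the buffer during warmup therefore keeps the EMA locked to the training weights.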
@@ -249,8 +252,16 @@ def main(args):

     model_ema = None
     if args.model_ema:
-        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
-        model_ema.update_steps = args.model_ema_steps
+        # Decay adjustment that aims to keep the decay independent of other hyper-parameters, originally proposed at:
+        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
+        #
+        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size * EMA_steps)
+        # We consider constant = (Dataset_size / n_GPUs) for a given dataset/setup and omit it. Thus:
+        # adjust = 1 / total_ema_updates ~= batch_size * EMA_steps / epochs
+        adjust = args.batch_size * args.model_ema_steps / args.epochs
+        alpha = 1.0 - args.model_ema_decay
+        alpha = min(1.0, alpha * adjust)
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
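To make the adjustment concrete, here is the same arithmetic with illustrative numbers; the per-GPU batch size and epoch count below are assumed values, not taken from this commit.

batch_size, model_ema_steps, epochs, model_ema_decay = 32, 32, 600, 0.99999  # assumed values

adjust = batch_size * model_ema_steps / epochs   # 1024 / 600 ≈ 1.71
alpha = 1.0 - model_ema_decay                    # 1e-05
alpha = min(1.0, alpha * adjust)                 # ≈ 1.71e-05
decay = 1.0 - alpha                              # ≈ 0.999983, the value handed to the EMA

Training for fewer epochs or with a larger batch size makes adjust bigger, which lowers the effective decay so the EMA tracks the model at roughly the same rate per epoch.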
@@ -376,11 +387,11 @@ def get_args_parser(add_help=True):
         '--model-ema', action='store_true',
         help='enable tracking Exponential Moving Average of model parameters')
     parser.add_argument(
-        '--model-ema-steps', type=float, default=32,
+        '--model-ema-steps', type=int, default=32,
         help='the number of iterations that controls how often to update the EMA model (default: 32)')
     parser.add_argument(
-        '--model-ema-decay', type=float, default=0.9,
-        help='decay factor for Exponential Moving Average of model parameters (default: 0.9)')
+        '--model-ema-decay', type=float, default=0.99999,
+        help='decay factor for Exponential Moving Average of model parameters (default: 0.99999)')

     return parser
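A quick, hypothetical way to check the new parser behaviour; this assumes references/classification is importable so that train.get_args_parser is available.

from train import get_args_parser  # hypothetical import path

args = get_args_parser().parse_args(['--model-ema'])
assert isinstance(args.model_ema_steps, int)   # CLI values now parse as int rather than float
assert args.model_ema_decay == 0.99999         # default raised from 0.9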

