diff --git a/references/classification/train.py b/references/classification/train.py index 4d93f073454..40c174c35d7 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -260,10 +260,10 @@ def main(args): # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at: # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123 # - # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size * EMA_steps) - # We consider constant = (Dataset_size / n_GPUs) for a given dataset/setup and ommit it. Thus: - # adjust = 1 / total_ema_updates ~= batch_size * EMA_steps / epochs - adjust = args.batch_size * args.model_ema_steps / args.epochs + # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps) + # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus: + # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs + adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs alpha = 1.0 - args.model_ema_decay alpha = min(1.0, alpha * adjust) model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha) @@ -397,8 +397,8 @@ def get_args_parser(add_help=True): '--model-ema-steps', type=int, default=32, help='the number of iterations that controls how often to update the EMA model (default: 32)') parser.add_argument( - '--model-ema-decay', type=float, default=0.99999, - help='decay factor for Exponential Moving Average of model parameters (default: 0.99999)') + '--model-ema-decay', type=float, default=0.99998, + help='decay factor for Exponential Moving Average of model parameters (default: 0.99998)') return parser