From 2867310a0d82cef52a58e51da0a0ce816ce2e640 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Sat, 18 Sep 2021 13:30:07 +0100
Subject: [PATCH 1/8] Adding ExponentialLR and LinearLR

---
 references/classification/train.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 90abdb0b47e..e3440521ddb 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -214,15 +214,25 @@ def main(args):
     elif args.lr_scheduler == 'cosineannealinglr':
         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                        T_max=args.epochs - args.lr_warmup_epochs)
+    elif args.lr_scheduler == 'exponentiallr':
+        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
     else:
-        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
+        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
                            "are supported.".format(args.lr_scheduler))
 
     if args.lr_warmup_epochs > 0:
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=args.lr_warmup_epochs)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=args.lr_warmup_epochs)
+        else:
+            raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant "
+                               "are supported.".format(args.lr_warmup_method))
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
             optimizer,
-            schedulers=[torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
-                                                            total_iters=args.lr_warmup_epochs), main_lr_scheduler],
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
             milestones=[args.lr_warmup_epochs]
         )
     else:
@@ -307,6 +317,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
+    parser.add_argument('--lr-warmup-method', default="constant", type=str,
+                        help='the warmup method (default: constant)')
     parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')

From ae7d7e92c5a97bd0b18fdc42b0090dae5ad8dd37 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Sat, 18 Sep 2021 14:32:32 +0100
Subject: [PATCH 2/8] Fix arg type of --lr-warmup-decay

---
 references/classification/train.py | 2 +-
 references/segmentation/train.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index e3440521ddb..34bed3f14ce 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -319,7 +319,7 @@ def get_args_parser(add_help=True):
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
     parser.add_argument('--lr-warmup-method', default="constant", type=str,
                         help='the warmup method (default: constant)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 476058ce0c0..83277de9c2c 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -220,7 +220,7 @@ def get_args_parser(add_help=True):
                         dest='weight_decay')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
     parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')

From 53080980042827e1c0c918e1a3c61c0c4c550d1b Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Sun, 19 Sep 2021 23:47:16 +0100
Subject: [PATCH 3/8] Adding support of Zero gamma BN and SGD with nesterov.

---
 references/classification/train.py | 15 +++++++++++++--
 references/classification/utils.py | 18 ++++++++++++++++++
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 34bed3f14ce..587ea6ad48d 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -188,15 +188,20 @@ def main(args):
     print("Creating model")
     model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
     model.to(device)
+
+    if not args.pretrained and args.zero_gamma_bn:
+        utils.bn_reinitialization(model, gamma=0.0)
+
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
     opt_name = args.opt.lower()
-    if opt_name == 'sgd':
+    if opt_name.startswith("sgd"):
         optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name)
     elif opt_name == 'rmsprop':
         optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
                                         weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
@@ -339,6 +344,12 @@ def get_args_parser(add_help=True):
         help="Use sync batch norm",
         action="store_true",
     )
+    parser.add_argument(
+        "--zero-gamma-bn",
+        dest="zero_gamma_bn",
+        help="Initialize the gamma of batch norm with zero",
+        action="store_true",
+    )
     parser.add_argument(
         "--test-only",
         dest="test_only",
diff --git a/references/classification/utils.py b/references/classification/utils.py
index fad607636e5..2e9b5a5302c 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -184,6 +184,24 @@ def update_parameters(self, model):
         self.n_averaged += 1
 
 
+def bn_reinitialization(model: torch.nn.Module, gamma: float = 1.0, beta: float = 0.0):
+    """
+    This method overwrites the default gamma and beta initial values for BatchNorm layers.
+    It can be used to perform Zero gamma initialization as described at
+    `"Bag of Tricks for Image Classification with Convolutional Neural Networks"
+    <https://arxiv.org/abs/1812.01187>`_.
+
+    Args:
+        model (nn.Module): The model on which we perform the BatchNorm re-initialization.
+        gamma (float): The gamma initial value.
+        beta (float): The beta initial value.
+    """
+    for module in model.modules():
+        if isinstance(module, torch.nn._BatchNorm):
+            module.weight.fill_(gamma)
+            module.bias.fill_(beta)
+
+
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():

From 12357aa769db7a640c8fff9c182f9e6ece1acaa2 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Sun, 19 Sep 2021 23:56:40 +0100
Subject: [PATCH 4/8] Fix --lr-warmup-decay for video_classification.

---
 references/video_classification/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 353e0d6d1f7..0eefbc0b282 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -296,7 +296,7 @@ def parse_args():
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='the number of epochs to warmup (default: 10)')
     parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
-    parser.add_argument('--lr-warmup-decay', default=0.001, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.001, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')

From e51ddb0a1dd02d33deff6a529093f4d3a4dab9e6 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 20 Sep 2021 00:42:09 +0100
Subject: [PATCH 5/8] Update bn_reinit

---
 references/classification/utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/references/classification/utils.py b/references/classification/utils.py
index 2e9b5a5302c..62c4dcb0460 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -197,9 +197,9 @@ def bn_reinitialization(model: torch.nn.Module, gamma: float = 1.0, beta: float
         beta (float): The beta initial value.
     """
     for module in model.modules():
-        if isinstance(module, torch.nn._BatchNorm):
-            module.weight.fill_(gamma)
-            module.bias.fill_(beta)
+        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+            torch.nn.init.constant_(module.weight, gamma)
+            torch.nn.init.constant_(module.bias, beta)
 
 
 def accuracy(output, target, topk=(1,)):

From 1b2dc9565aebf05d3f54c71cb23fa787e185bfb7 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 20 Sep 2021 10:57:02 +0100
Subject: [PATCH 6/8] Fix pre-existing bug on num_classes of model

---
 references/classification/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 587ea6ad48d..e9bf3f7c885 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -186,7 +186,7 @@ def main(args):
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
 
     print("Creating model")
-    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
+    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
     model.to(device)
 
     if not args.pretrained and args.zero_gamma_bn:

From 29615b7eab317d1db1d11e121f9f975d89198d6d Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 20 Sep 2021 11:11:24 +0100
Subject: [PATCH 7/8] Remove zero gamma.

---
 references/classification/train.py |  9 ---------
 references/classification/utils.py | 18 ------------------
 2 files changed, 27 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index e9bf3f7c885..6690793989d 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -189,9 +189,6 @@ def main(args):
     model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
     model.to(device)
 
-    if not args.pretrained and args.zero_gamma_bn:
-        utils.bn_reinitialization(model, gamma=0.0)
-
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
@@ -344,12 +341,6 @@ def get_args_parser(add_help=True):
         help="Use sync batch norm",
         action="store_true",
     )
-    parser.add_argument(
-        "--zero-gamma-bn",
-        dest="zero_gamma_bn",
-        help="Initialize the gamma of batch norm with zero",
-        action="store_true",
-    )
     parser.add_argument(
         "--test-only",
         dest="test_only",
diff --git a/references/classification/utils.py b/references/classification/utils.py
index 62c4dcb0460..fad607636e5 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -184,24 +184,6 @@ def update_parameters(self, model):
         self.n_averaged += 1
 
 
-def bn_reinitialization(model: torch.nn.Module, gamma: float = 1.0, beta: float = 0.0):
-    """
-    This method overwrites the default gamma and beta initial values for BatchNorm layers.
-    It can be used to perform Zero gamma initialization as described at
-    `"Bag of Tricks for Image Classification with Convolutional Neural Networks"
-    <https://arxiv.org/abs/1812.01187>`_.
-
-    Args:
-        model (nn.Module): The model on which we perform the BatchNorm re-initialization.
-        gamma (float): The gamma initial value.
-        beta (float): The beta initial value.
-    """
-    for module in model.modules():
-        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-            torch.nn.init.constant_(module.weight, gamma)
-            torch.nn.init.constant_(module.bias, beta)
-
-
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():

From cb7905c5c6202b0fd2a93aeec1e6932563ec94b4 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Tue, 21 Sep 2021 19:59:30 +0100
Subject: [PATCH 8/8] Use fstrings.

---
 references/classification/train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 6690793989d..48ab75bc2c1 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -230,8 +230,8 @@ def main(args):
         warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
                                                                   total_iters=args.lr_warmup_epochs)
     else:
-        raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant "
-                           "are supported.".format(args.lr_warmup_method))
+        raise RuntimeError(f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant "
+                           "are supported.")
     lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
         optimizer,
         schedulers=[warmup_lr_scheduler, main_lr_scheduler],