diff --git a/references/classification/train.py b/references/classification/train.py
index 36d27fa40e6..af018797561 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -197,16 +197,22 @@ def main(args):
 
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
+    if args.norm_weight_decay is None:
+        parameters = model.parameters()
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
-        optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
-            nesterov="nesterov" in opt_name)
+        optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+                                    nesterov="nesterov" in opt_name)
     elif opt_name == 'rmsprop':
-        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
-                                        weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
+        optimizer = torch.optim.RMSprop(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+                                        eps=0.0316, alpha=0.9)
     elif opt_name == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
     else:
         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
 
@@ -326,6 +332,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
+    parser.add_argument('--norm-weight-decay', default=None, type=float,
+                        help='weight decay for Normalization layers (default: None, same value as --wd)')
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
diff --git a/test/test_ops.py b/test/test_ops.py
index 8ab23f3ff64..36682adc6d3 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -9,10 +9,10 @@
 from PIL import Image
 import torch
 from functools import lru_cache
-from torch import Tensor
+from torch import nn, Tensor
 from torch.autograd import gradcheck
 from torch.nn.modules.utils import _pair
-from torchvision import ops
+from torchvision import models, ops
 from typing import Tuple
 
 
@@ -1062,5 +1062,15 @@ def test_stochastic_depth(self, mode, p):
         assert p_value > 0.0001
 
 
+class TestUtils:
+    @pytest.mark.parametrize('norm_layer', [None, nn.BatchNorm2d, nn.LayerNorm])
+    def test_split_normalization_params(self, norm_layer):
+        model = models.mobilenet_v3_large(norm_layer=norm_layer)
+        params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])
+
+        assert len(params[0]) == 92
+        assert len(params[1]) == 82
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py
index 7cc6367a7a4..2cf47f31c72 100644
--- a/torchvision/ops/_utils.py
+++ b/torchvision/ops/_utils.py
@@ -1,6 +1,6 @@
 import torch
-from torch import Tensor
-from typing import List, Union
+from torch import nn, Tensor
+from typing import List, Optional, Tuple, Union
 
 
 def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
@@ -34,3 +34,27 @@ def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
     else:
         assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]'
     return
+
+
+def split_normalization_params(model: nn.Module,
+                               norm_classes: Optional[List[type]] = None) -> Tuple[List[Tensor], List[Tensor]]:
+    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
+    if not norm_classes:
+        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
+
+    for t in norm_classes:
+        if not issubclass(t, nn.Module):
+            raise ValueError(f"Class {t} is not a subclass of nn.Module.")
+
+    classes = tuple(norm_classes)
+
+    norm_params = []
+    other_params = []
+    for module in model.modules():
+        if next(module.children(), None):
+            other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
+        elif isinstance(module, classes):
+            norm_params.extend(p for p in module.parameters() if p.requires_grad)
+        else:
+            other_params.extend(p for p in module.parameters() if p.requires_grad)
+    return norm_params, other_params
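
A minimal usage sketch of the new helper, not part of the diff itself: it mirrors what the train.py change does when --norm-weight-decay is supplied. The model choice and the weight-decay values below are illustrative assumptions, not defaults taken from the patch.

import torch
import torchvision
from torchvision.ops._utils import split_normalization_params

# Any nn.Module works; mobilenet_v3_large is the model used in the new test.
model = torchvision.models.mobilenet_v3_large()

# Split trainable parameters into (normalization-layer params, all other params).
norm_params, other_params = split_normalization_params(model)

# Assumed example values: disable decay on norm layers, keep 1e-4 elsewhere,
# matching wd_groups = [args.norm_weight_decay, args.weight_decay] above.
param_groups = [
    {"params": norm_params, "weight_decay": 0.0},
    {"params": other_params, "weight_decay": 1e-4},
]
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)

With the reference script itself, the same grouping is enabled through the new flag, e.g. passing --norm-weight-decay 0.0 while leaving --wd at its 1e-4 default.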