Support custom weight decay for Normalization layers.
datumbox committed Sep 28, 2021
1 parent 33a90f7 commit 5724c1c
Showing 3 changed files with 52 additions and 10 deletions.
20 changes: 14 additions & 6 deletions references/classification/train.py
@@ -197,16 +197,22 @@ def main(args):

     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
+    if args.norm_weight_decay is None:
+        parameters = model.parameters()
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
-        optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
-            nesterov="nesterov" in opt_name)
+        optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+                                    nesterov="nesterov" in opt_name)
     elif opt_name == 'rmsprop':
-        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
-                                        weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
+        optimizer = torch.optim.RMSprop(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+                                        eps=0.0316, alpha=0.9)
     elif opt_name == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
     else:
         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
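The hunk above leans on a standard torch.optim feature: when the optimizer is given a list of dicts instead of a flat iterable, each dict becomes a parameter group whose keys override the optimizer-wide defaults. A minimal, self-contained sketch of that mechanism (the toy model and the 0.0 / 1e-4 values are illustrative, not part of this commit):

    import torch
    from torch import nn

    # Toy model: a conv layer followed by a BatchNorm layer.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

    # Hand-built groups: BatchNorm parameters get no weight decay,
    # the conv parameters keep the optimizer-wide value.
    param_groups = [
        {"params": list(model[1].parameters()), "weight_decay": 0.0},
        {"params": list(model[0].parameters()), "weight_decay": 1e-4},
    ]
    optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4)

Groups that do not set "weight_decay" fall back to the value passed to the constructor, which is why the script can keep forwarding args.weight_decay unchanged.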

@@ -326,6 +332,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
+    parser.add_argument('--norm-weight-decay', default=None, type=float,
+                        help='weight decay for Normalization layers (default: None, same value as --wd)')
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
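With the new flag in place, a run that disables weight decay on normalization layers while keeping 1e-4 everywhere else would look roughly like the following (a hypothetical invocation; the model name and any other hyper-parameters are placeholders, only --wd and --norm-weight-decay come from this diff):

    python train.py --model mobilenet_v3_large --wd 1e-4 --norm-weight-decay 0.0

Leaving --norm-weight-decay unset preserves the previous behaviour: main() falls back to model.parameters() and a single optimizer-wide weight decay.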
14 changes: 12 additions & 2 deletions test/test_ops.py
@@ -9,10 +9,10 @@
 from PIL import Image
 import torch
 from functools import lru_cache
-from torch import Tensor
+from torch import nn, Tensor
 from torch.autograd import gradcheck
 from torch.nn.modules.utils import _pair
-from torchvision import ops
+from torchvision import models, ops
 from typing import Tuple


@@ -1062,5 +1062,15 @@ def test_stochastic_depth(self, mode, p):
         assert p_value > 0.0001
 
 
+class TestUtils:
+    @pytest.mark.parametrize('norm_layer', [None, nn.BatchNorm2d, nn.LayerNorm])
+    def test_split_normalization_params(self, norm_layer):
+        model = models.mobilenet_v3_large(norm_layer=norm_layer)
+        params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])
+
+        assert len(params[0]) == 92
+        assert len(params[1]) == 82
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
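To exercise just the new test locally, something like the following should work (pytest's -k filter is standard; the exact checkout layout is assumed):

    pytest test/test_ops.py -k test_split_normalization_params -v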
28 changes: 26 additions & 2 deletions torchvision/ops/_utils.py
@@ -1,6 +1,6 @@
 import torch
-from torch import Tensor
-from typing import List, Union
+from torch import nn, Tensor
+from typing import List, Optional, Tuple, Union
 
 
 def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
@@ -34,3 +34,27 @@ def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
     else:
         assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]'
     return
+
+
+def split_normalization_params(model: nn.Module,
+                               norm_classes: Optional[List[type]] = None) -> Tuple[List[Tensor], List[Tensor]]:
+    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
+    if not norm_classes:
+        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
+
+    for t in norm_classes:
+        if not issubclass(t, nn.Module):
+            raise ValueError(f"Class {t} is not a subclass of nn.Module.")
+
+    classes = tuple(norm_classes)
+
+    norm_params = []
+    other_params = []
+    for module in model.modules():
+        if next(module.children(), None):
+            other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
+        elif isinstance(module, classes):
+            norm_params.extend(p for p in module.parameters() if p.requires_grad)
+        else:
+            other_params.extend(p for p in module.parameters() if p.requires_grad)
+    return norm_params, other_params
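The helper walks model.modules(): container modules contribute only their direct (recurse=False) parameters to the second group, leaf modules matching one of the normalization classes go to the first group, and every other leaf goes to the second group, so each parameter is counted exactly once. A short usage sketch outside the reference script (the resnet18 choice and the 0.0 / 1e-4 values are illustrative):

    import torch
    import torchvision
    from torchvision.ops._utils import split_normalization_params

    model = torchvision.models.resnet18()
    norm_params, other_params = split_normalization_params(model)

    # No weight decay on the normalization affine parameters, 1e-4 everywhere else.
    optimizer = torch.optim.SGD(
        [{"params": norm_params, "weight_decay": 0.0},
         {"params": other_params, "weight_decay": 1e-4}],
        lr=0.1, momentum=0.9,
    )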
