Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
else:
loss = criterion(output, target)
loss.backward()

if args.clip_grad_norm is not None:
nn.utils.clip_grad_norm_(utils.master_params(optimizer), args.clip_grad_norm)

optimizer.step()

if model_ema and i % args.model_ema_steps == 0:
Expand Down Expand Up @@ -472,6 +476,7 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
Expand Down
19 changes: 19 additions & 0 deletions references/classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,22 @@ def reduce_across_processes(val):
dist.barrier()
dist.all_reduce(t)
return t


try:
import apex

apex_available = True
except ImportError:
apex_available = False


def master_params(optimizer):
"""Generator to iterate over all parameters in the optimizer param_groups."""

if apex_available:
yield from apex.amp.master_params(optimizer)
else:
for group in optimizer.param_groups:
for p in group["params"]:
yield p