Skip to content

Commit

Permalink
Add back nesterov momentum
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Apr 1, 2022
1 parent ff6d641 commit 8455b37
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,13 @@ def main(args):

opt_name = args.opt.lower()
if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
optimizer = torch.optim.SGD(
parameters,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov="nesterov" in opt_name,
)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
else:
Expand Down

0 comments on commit 8455b37

Please sign in to comment.