From 8455b3708bac4ef038b6c5e6d222640c4d6c41fd Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 1 Apr 2022 08:29:17 +0100 Subject: [PATCH] Add back nesterov momentum --- references/detection/train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/references/detection/train.py b/references/detection/train.py index efe6208012f..758171013e8 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -209,7 +209,13 @@ def main(args): opt_name = args.opt.lower() if opt_name.startswith("sgd"): - optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + optimizer = torch.optim.SGD( + parameters, + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + nesterov="nesterov" in opt_name, + ) elif opt_name == "adamw": optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay) else: