
Additional SOTA ingredients on Classification Recipe #4493

Merged 27 commits into pytorch:main on Oct 22, 2021

Conversation

@datumbox (Contributor) commented on Sep 28, 2021

Partially resolves #3995

Add support for the following SOTA ingredients in our classification recipe:

  • EMA per iteration + adjusted decay
  • AdamW support
  • Support custom weight decay for Normalization layers
  • FixRes corrections

Based on the work of @pdollar and @mannatsingh on pycls, and inspired by their work on "Early Convolutions Help Transformers See Better". Also contains improvements from the work of @TouvronHugo on "Fixing the train-test resolution discrepancy".
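Since the EMA ingredient is the headline change, here is a minimal sketch of the kind of wrapper it relies on, built on torch.optim.swa_utils.AveragedModel. It illustrates the idea behind the utils.ExponentialMovingAverage used below, but it is not a verbatim copy of the code in this PR.

```
import torch
from torch.optim.swa_utils import AveragedModel


class ExponentialMovingAverage(AveragedModel):
    """Keeps an EMA of the model parameters: ema = decay * ema + (1 - decay) * param."""

    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_param, param, num_averaged):
            return decay * avg_param + (1.0 - decay) * param

        super().__init__(model, device, ema_avg)


# Example wiring (model/device assumed to come from train.py's main()):
# model_ema = ExponentialMovingAverage(model, decay=0.99998, device=device)
# model_ema.update_parameters(model)  # called periodically during training
```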

cc @datumbox @sallysyw

@datumbox datumbox marked this pull request as draft September 28, 2021 11:35
@datumbox datumbox changed the title from "[WIP] Additional SOTA ingredient on Classification Recipe" to "[WIP] Additional SOTA ingredients on Classification Recipe" on Sep 28, 2021
@datumbox datumbox marked this pull request as ready for review September 28, 2021 18:20
@kazhang (Contributor) left a comment

LGTM overall. I would recommend splitting this into a norm PR and an EMA PR so that we can easily track the changes in the commit history.

references/classification/utils.py (review comment, outdated, resolved)
@datumbox (Contributor, Author)

@kazhang Awesome, thanks for confirming.

I need to keep the changes "stacked" for now to be able to test them on the new recipes, but I can certainly split them prior to merging. I wanted to get your eyes on this early, as some of the approaches are adopted from ClassyVision, which you know very well. :)

@prabhat00155 (Contributor)

Thanks @datumbox, looks good to me. Could you share the training logs for the EMA run?

@datumbox (Contributor, Author)

@prabhat00155 Sure thing. Still running stuff. Happy to provide them when I'm done.

@datumbox (Contributor, Author) left a comment

Adding comments to improve reviewing:

```
torchrun --nproc_per_node=8 train.py --model inception_v3 \
    --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
```

@datumbox (Contributor, Author)

Since we removed the hardcoding of parameters based on model names, we now need to provide extra parameters.

Contributor

One thing on my mind, not related to this PR: could we also let users pass kwargs to the models through the command line (in addition to the train.py arguments)?
For example, when I train the ViT model, training from scratch and fine-tuning require two different heads; in this case I want to configure the representation_size differently, and currently I need to manually change the Python defaults to reflect this.
wdyt?

@datumbox (Contributor, Author)

We will probably need to introduce more parameters to be able to do this. We will do it to enable your work, but it's also part of the reason why the ArgumentParser is a poor solution. Hopefully this will be superseded by the STL work you are preparing!
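Not something this PR implements, but one possible shape for the kwargs idea above: a hypothetical --model-kwargs flag whose KEY=VALUE pairs are forwarded to the model builder. The flag name and helper below are illustrative assumptions only.

```
import argparse
import ast


def parse_model_kwargs(pairs):
    """Turn ["representation_size=768", "dropout=0.1"] into a kwargs dict."""
    kwargs = {}
    for pair in pairs:
        key, _, value = pair.partition("=")
        try:
            kwargs[key] = ast.literal_eval(value)  # numbers, bools, tuples, ...
        except (ValueError, SyntaxError):
            kwargs[key] = value  # keep plain strings as-is
    return kwargs


parser = argparse.ArgumentParser()
parser.add_argument("--model-kwargs", nargs="*", default=[], metavar="KEY=VALUE")
args = parser.parse_args(["--model-kwargs", "representation_size=768"])
model_kwargs = parse_model_kwargs(args.model_kwargs)  # {"representation_size": 768}
# model = torchvision.models.__dict__[args.model](**model_kwargs)  # hypothetical usage
```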

 else:
     aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-    trans.append(autoaugment.AutoAugment(policy=aa_policy))
+    trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
@datumbox (Contributor, Author)

The change to the interpolation here is not backwards compatible, but I consider this a bug fix rather than a previous feature. In the previous recipe there was a mismatch between the interpolation used for resizing and the one used by the AutoAugment methods.
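For context, a small sketch of the intended behaviour (the crop size and interpolation values are just example inputs, not the recipe's defaults): the AutoAugment transform now receives the same InterpolationMode as the resize/crop transforms.

```
from torchvision import transforms
from torchvision.transforms import InterpolationMode, autoaugment

interpolation = InterpolationMode("bilinear")  # e.g. taken from args.interpolation
trans = [transforms.RandomResizedCrop(176, interpolation=interpolation)]
aa_policy = autoaugment.AutoAugmentPolicy("imagenet")
# Same interpolation for AutoAugment as for the resize/crop above:
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
```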

@@ -40,16 +38,19 @@ def train_one_epoch(
         loss.backward()
         optimizer.step()

+        if model_ema and i % args.model_ema_steps == 0:
+            model_ema.update_parameters(model)
@datumbox (Contributor, Author)

Moving the EMA updates to a per-iteration level rather than per epoch.
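Roughly where the update now sits (a trimmed sketch of the training loop, omitting AMP, gradient clipping and logging; not the exact code):

```
for i, (image, target) in enumerate(data_loader):
    image, target = image.to(device), target.to(device)
    output = model(image)
    loss = criterion(output, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # EMA is refreshed every `model_ema_steps` iterations rather than once per epoch
    if model_ema and i % args.model_ema_steps == 0:
        model_ema.update_parameters(model)
```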

model_ema.update_parameters(model)
if epoch < args.lr_warmup_epochs:
    # Reset ema buffer to keep copying weights during warmup period
    model_ema.n_averaged.fill_(0)
@datumbox (Contributor, Author)

Always copy the weights (rather than averaging them) during the warmup period.
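A quick way to convince yourself of the effect (an illustrative check, not part of the PR): once n_averaged is zeroed, the next update_parameters() call copies the weights verbatim instead of averaging them.

```
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(2, 2)
ema = AveragedModel(model)

ema.update_parameters(model)      # the first update is always a straight copy
with torch.no_grad():
    model.weight.add_(1.0)        # pretend an optimizer step changed the weights

ema.n_averaged.fill_(0)           # the warmup reset from the diff above
ema.update_parameters(model)      # copies again instead of averaging
assert torch.allclose(ema.module.weight, model.weight)
```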

-    resize_size, crop_size = sizes[e_type]
-    interpolation = InterpolationMode.BICUBIC
+    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
+    interpolation = InterpolationMode(args.interpolation)
@datumbox (Contributor, Author)

Remove the hardcoding of resize/crop sizes based on model names; use parameters instead.

    )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
@datumbox (Contributor, Author)

Adding AdamW, which is necessary for training ViT.
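The "custom weight decay for Normalization layers" ingredient pairs naturally with this: split the parameters into norm vs. other groups and give each group its own weight decay. The helper and the args.norm_weight_decay name below are an illustrative sketch, not necessarily what utils.py ends up with.

```
import torch
from torch import nn

NORM_CLASSES = (nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm)


def split_normalization_params(model):
    norm_params, other_params = [], []
    for module in model.modules():
        if next(module.children(), None) is not None:
            continue  # only inspect leaf modules
        params = [p for p in module.parameters(recurse=False) if p.requires_grad]
        if isinstance(module, NORM_CLASSES):
            norm_params.extend(params)
        else:
            other_params.extend(params)
    return norm_params, other_params


# norm_params, other_params = split_normalization_params(model)
# param_groups = [
#     {"params": other_params, "weight_decay": args.weight_decay},
#     {"params": norm_params, "weight_decay": args.norm_weight_decay},  # often 0.0
# ]
# optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
```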

adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
alpha = 1.0 - args.model_ema_decay
alpha = min(1.0, alpha * adjust)
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
@datumbox (Contributor, Author)

Adjust the EMA decay so that it is parameterized independently of the number of epochs.

-        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+        if not args.test_only:
+            optimizer.load_state_dict(checkpoint["optimizer"])
+            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
@datumbox (Contributor, Author)

Quality-of-life improvement to avoid the very annoying error messages when you don't define all the optimizer params during validation.

    if model_ema:
        evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
    else:
        evaluate(model, criterion, data_loader_test, device=device)
@datumbox (Contributor, Author)

Choose which model to validate depending on the flag provided.

-    default=0.9,
-    help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
+    default=0.99998,
+    help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
@datumbox (Contributor, Author)

Reconfiguring the default EMA decay value now that we update per iteration instead of per epoch.

Contributor

n00b q: Is this default value 0.99998 used most often?

@datumbox (Contributor, Author)

It's a good guess for ImageNet, considering the typical batch size for 8 GPUs. The reason for changing it so drastically is that we switched from one update per epoch to an update every X iterations (X=32, configurable).
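For a rough sense of scale, plugging assumed values into the adjustment above (8 GPUs, batch size 128, EMA every 32 iterations, 600 epochs; illustrative numbers, not necessarily the final recipe):

```
world_size, batch_size, model_ema_steps, epochs = 8, 128, 32, 600
model_ema_decay = 0.99998

adjust = world_size * batch_size * model_ema_steps / epochs  # ~54.6
alpha = min(1.0, (1.0 - model_ema_decay) * adjust)           # ~1.09e-3
print(f"effective per-update decay: {1.0 - alpha:.5f}")      # ~0.99891
```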

@datumbox datumbox changed the title from "[WIP] Additional SOTA ingredients on Classification Recipe" to "Additional SOTA ingredients on Classification Recipe" on Oct 21, 2021
-def train_one_epoch(
-    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
-):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
@yiwen-song (Contributor) commented on Oct 21, 2021

I'm not a huge fan of passing the whole args to a single method (it's not clear what is actually needed by the function), but I can see you do this just to reduce the number of arguments.
In the future we might want to add type hints for all the args used in this script, as well as some documentation.

@datumbox (Contributor, Author)

Indeed, args is passed to reduce the number of parameters (merging 4 into 1). The same pattern is already used in other places of the script, for example:

def load_data(traindir, valdir, args):

Concerning type hints/documentation, I think you are right. For some reason most of the script args don't define them. I've raised a new issue, #4694, to improve this.
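For reference, one possible typed signature along the lines discussed (a sketch for the follow-up issue, not something this PR adds):

```
import argparse
from typing import Optional

import torch
import torch.utils.data
from torch.cuda.amp import GradScaler


def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    epoch: int,
    args: argparse.Namespace,
    model_ema: Optional[torch.nn.Module] = None,
    scaler: Optional[GradScaler] = None,
) -> None:
    ...
```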

@yiwen-song (Contributor)

Overall LGTM. Would you like to share the training logs for the EMA run? Then I think we are good to go!

@datumbox (Contributor, Author)

@sallysyw Thanks for the review. I still have a few jobs running (I'll post everything once I finish, and possibly write a blog post), but I'll send you the logs of the best current model.

@datumbox datumbox merged commit b280c31 into pytorch:main Oct 22, 2021
@datumbox datumbox deleted the references/optimizations branch October 22, 2021 11:31
@github-actions

Hey @datumbox!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request Oct 26, 2021
Summary:
* Update EMA every X iters.

* Adding AdamW optimizer.

* Adjusting EMA decay scheme.

* Support custom weight decay for Normalization layers.

* Fix indentation bug.

* Change EMA adjustment.

* Quality of life changes to facilitate testing

* ufmt format

* Fixing imports.

* Adding FixRes improvement.

* Support EMA in store_model_weights.

* Adding interpolation values.

* Change train_crop_size.

* Add interpolation option.

* Removing hardcoded interpolation and sizes from the scripts.

* Fixing linter.

* Incorporating feedback from code review.

Reviewed By: NicolasHug

Differential Revision: D31916313

fbshipit-source-id: 6136c02dd6d511d0f327b5a72c9056a134abc697
cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021
Development

Successfully merging this pull request may close these issues.

Improve the accuracy of Classification models by using SOTA recipes and primitives
5 participants