
Potential bug in StochasticWeightAveraging causing incorrect batchnorm statistics #14866

Closed
5 tasks done
zxvix opened this issue Sep 23, 2022 · 1 comment · Fixed by #15113
Labels
bug (Something isn't working) · callback: swa · discussion (In a discussion stage)

Comments

@zxvix (Contributor) commented Sep 23, 2022

First check

  • I'm sure this is a bug.
  • I've added a descriptive title to this bug.
  • I've provided clear instructions on how to reproduce the bug.
  • I've added a code sample.
  • I've provided any other important info that is required.

Bug description

I recently noticed that my model (a ResNet trained on a subset of ImageNet) performs very badly (accuracy ~0) after training with StochasticWeightAveraging, while the same model performs well (accuracy ~0.75) without it. After some investigation, the problem appears to be incorrect batchnorm statistics, and I traced the cause to this line:

https://github.com/Lightning-AI/lightning/blob/c77d4a8394a307992b27fe935831b2fc83b5ce6c/src/pytorch_lightning/callbacks/stochastic_weight_avg.py#L299

Compare this with the PyTorch SWA implementation, which has:
https://github.com/pytorch/pytorch/blob/57bffc3a8e4fee0cce31e1ff1f662ccf7b16db57/torch/optim/swa_utils.py#L153

module.momentum = None

Setting module.momentum to None versus 0. behaves differently, see
https://github.com/pytorch/pytorch/blob/b136f3f310aa01a8b3c1e63dc0bfda8fd2234b06/torch/nn/modules/batchnorm.py#L152-L155
With momentum = None, BatchNorm switches to a cumulative moving average (the update factor becomes 1 / num_batches_tracked), so the running statistics accumulate over all batches seen, which is the behavior we want here. With momentum = 0., the update factor is 0 and the running statistics never move away from their reset values (mean 0, variance 1), which would explain the near-zero accuracy.
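A quick standalone check of the two behaviors (an illustrative snippet I put together; the tensor shapes and values are arbitrary):

```python
import torch
from torch import nn

x = torch.randn(32, 8) * 3.0 + 5.0  # a batch with mean ~5 and std ~3

# momentum=0.: the exponential update factor is 0, so the running stats
# never move away from their reset values.
bn_zero = nn.BatchNorm1d(8, momentum=0.0).train()
bn_zero(x)
print(bn_zero.running_mean)  # still ~0 -> stale stats at eval time

# momentum=None: cumulative moving average over all batches seen so far.
bn_none = nn.BatchNorm1d(8, momentum=None).train()
bn_none(x)
print(bn_none.running_mean)  # ~5, i.e. it actually tracks the data
```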

After changing this line to module.momentum = None, my model's performance returned to normal.
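For reference, here is a minimal sketch of what the corrected reset could look like, following torch.optim.swa_utils.update_bn; the function name and structure are illustrative, not the actual Lightning callback code:

```python
import torch
from torch import nn

def reset_batch_norm_for_swa(model: nn.Module) -> dict:
    """Reset BatchNorm buffers so their statistics can be recomputed for the SWA weights."""
    momenta = {}
    for module in model.modules():
        if not isinstance(module, nn.modules.batchnorm._BatchNorm):
            continue
        if module.running_mean is None:  # track_running_stats=False, nothing to reset
            continue
        module.running_mean = torch.zeros_like(module.running_mean)
        module.running_var = torch.ones_like(module.running_var)
        momenta[module] = module.momentum
        # None (not 0.) makes BatchNorm accumulate a cumulative average over
        # every batch seen during the extra statistics-collection pass.
        module.momentum = None
        module.num_batches_tracked *= 0
    return momenta  # restore each module.momentum from this dict afterwards
```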

How to reproduce the bug

I'm not sure whether others can reproduce this bug, but we can start by discussing it based on the code.

Error messages and logs

No response

Important info

- Lightning Component: Callback
- PyTorch Lightning Version: 1.7.7
- PyTorch Version: 1.12.1
- Python version: 3.10

More info

No response

@zxvix added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Sep 23, 2022
@awaelchli added the callback: swa and discussion (In a discussion stage) labels and removed the needs triage label on Sep 24, 2022
@awaelchli (Contributor) commented:

Intuitively, I think that's right @zxvix. If you are interested in contributing a PR for this, we would really appreciate it!
