First check
I've provided clear instructions on how to reproduce the bug.
I've added a code sample.
I've provided any other important info that is required.
Bug description
I recently noticed that my model (a ResNet trained on a subset of ImageNet) performs very poorly (accuracy ~0) after training with StochasticWeightAveraging. Without StochasticWeightAveraging, the model performs well (accuracy ~0.75). After some investigation, I found that the problem seems to be incorrect batchnorm statistics, and traced the cause to this line:
https://github.com/Lightning-AI/lightning/blob/c77d4a8394a307992b27fe935831b2fc83b5ce6c/src/pytorch_lightning/callbacks/stochastic_weight_avg.py#L299
For comparison, the PyTorch SWA implementation has
https://github.com/pytorch/pytorch/blob/57bffc3a8e4fee0cce31e1ff1f662ccf7b16db57/torch/optim/swa_utils.py#L153
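For reference, here is a small sketch of what the PyTorch utility does when called directly. It assumes a toy model and a plain list of tensors standing in for a data loader (both are made up for illustration, not taken from my training setup). After `torch.optim.swa_utils.update_bn`, the running mean should equal the average of the per-batch means, and the original momentum should be restored.

```python
import torch
from torch import nn
from torch.optim.swa_utils import update_bn

torch.manual_seed(0)

# Toy model (hypothetical): a linear layer followed by batchnorm.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
bn = model[1]
original_momentum = bn.momentum  # PyTorch default is 0.1

# Stand-in for a data loader: a list of equally sized batches.
batches = [torch.randn(8, 4) for _ in range(5)]

# Recompute the batchnorm statistics with one pass over the data.
update_bn(batches, model)

# Average of the per-batch means of the batchnorm input.
with torch.no_grad():
    expected_mean = torch.stack(
        [model[0](b).mean(dim=0) for b in batches]
    ).mean(dim=0)

print("running_mean :", bn.running_mean)
print("expected_mean:", expected_mean)
print("momentum restored:", bn.momentum == original_momentum)
```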
Setting `module.momentum` to `None` or `0.` behaves differently, according to https://github.com/pytorch/pytorch/blob/b136f3f310aa01a8b3c1e63dc0bfda8fd2234b06/torch/nn/modules/batchnorm.py#L152-L155
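The difference can be checked in isolation with a standalone batchnorm layer. This is a minimal sketch on random data (the helper `running_mean_after` and the tensor shapes are made up for illustration). With `momentum=0.0` the running statistics should never move from their initial values, while with `momentum=None` they should become the cumulative average over the batches seen.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Random batches with a clearly non-zero mean (about 5.0 per feature).
batches = [torch.randn(16, 3) + 5.0 for _ in range(4)]


def running_mean_after(momentum):
    """Feed all batches through a fresh BatchNorm1d in train mode
    and return the resulting running mean."""
    bn = nn.BatchNorm1d(3, momentum=momentum)
    bn.train()
    with torch.no_grad():
        for batch in batches:
            bn(batch)
    return bn.running_mean.clone()


mean_zero = running_mean_after(0.0)   # update factor is always 0
mean_none = running_mean_after(None)  # update factor is 1 / num_batches_tracked

# Average of the per-batch means, for comparison.
batch_mean_avg = torch.stack([b.mean(dim=0) for b in batches]).mean(dim=0)

print("momentum=0.0 :", mean_zero)
print("momentum=None:", mean_none)
print("batch average:", batch_mean_avg)
```

If the running statistics are reset before this pass, `momentum=0.0` leaves them at the reset values, which would be consistent with the near-zero accuracy described above.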
It seems that `module.momentum = None` accumulates a cumulative average of the batch statistics, which is the behavior we want. After changing the line to `module.momentum = None`, my model's performance returned to normal.
How to reproduce the bug
Error messages and logs
No response
Important info
More info
No response