Stochastic Weight Averaging #700
I didn't know about stochastic weight averaging, thanks a lot. I looked at your code and the PyTorch example and came up with a slightly different implementation based on yours:

```python
from torch.optim import swa_utils

from skorch.callbacks import Callback


class StochasticWeightAveraging(Callback):
    def __init__(
            self,
            swa_utils,
            swa_start=10,
            verbose=0,
            sink=print,
            **kwargs  # additional arguments to swa_utils.SWALR
    ):
        self.swa_utils = swa_utils
        self.swa_start = swa_start
        self.verbose = verbose
        self.sink = sink
        vars(self).update(kwargs)

    @property
    def kwargs(self):
        # These are the parameters that are passed to SWALR.
        # Parameters that don't belong there must be excluded.
        excluded = {'swa_utils', 'swa_start', 'verbose', 'sink'}
        kwargs = {key: val for key, val in vars(self).items()
                  if not (key in excluded or key.endswith('_'))}
        return kwargs

    def on_train_begin(self, net, **kwargs):
        self.optimizer_swa_ = self.swa_utils.SWALR(net.optimizer_, **self.kwargs)
        if not hasattr(net, 'module_swa_'):
            net.module_swa_ = self.swa_utils.AveragedModel(net.module_)

    def on_epoch_begin(self, net, **kwargs):
        if self.verbose and len(net.history) == self.swa_start + 1:
            self.sink("Using SWA to update parameters")

    def on_epoch_end(self, net, **kwargs):
        if len(net.history) >= self.swa_start + 1:
            net.module_swa_.update_parameters(net.module_)
            self.optimizer_swa_.step()

    def on_train_end(self, net, X, y=None, **kwargs):
        if self.verbose:
            self.sink("Using training data to update batch norm statistics of the SWA model")
        loader = net.get_iterator(net.get_dataset(X, y))
        self.swa_utils.update_bn(loader, net.module_swa_, device=net.device)
```

Let me explain some of the changes:
What your example is missing compared to the PyTorch example is the use of …

Below is a working example using the callback as implemented above:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import swa_utils

from skorch import NeuralNetClassifier
from skorch.callbacks import Callback, LRScheduler, EpochScoring

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X, y = X.astype(np.float32), y.astype(np.int64)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

SWA_START = 5
MAX_EPOCHS = 100
LR = 0.01
LR_SWA = 0.05


# skorch implementation
class StochasticWeightAveraging(Callback):
    ...  # as defined above


torch.manual_seed(0)
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=50,
    lr=LR,
    callbacks=[
        LRScheduler(CosineAnnealingLR, T_max=MAX_EPOCHS),
        StochasticWeightAveraging(swa_utils, swa_start=SWA_START, verbose=1, swa_lr=LR_SWA),
        EpochScoring('accuracy', lower_is_better=False, on_train=True, name='train_acc'),
    ],
    train_split=False,
)
net.fit(X_train, y_train)
test_accuracy = (net.predict(X_test) == y_test).mean()

# PyTorch implementation inspired by the linked example
torch.manual_seed(0)
loader = net.get_iterator(net.get_dataset(X, y))
model = ClassifierModule()
optimizer = torch.optim.SGD(model.parameters(), LR)
loss_fn = torch.nn.NLLLoss()
swa_model = swa_utils.AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=MAX_EPOCHS)
swa_scheduler = swa_utils.SWALR(optimizer, swa_lr=LR_SWA)

for epoch in range(MAX_EPOCHS):
    losses = []
    for input, target in loader:
        optimizer.zero_grad()
        loss = loss_fn(torch.log(model(input)), target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    if epoch == 1 + SWA_START:
        print("starting SWA")

    if epoch > SWA_START:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

    preds = swa_model(torch.as_tensor(X))
    print("epoch: {:>2} | train loss: {:.4f} | train acc: {:.2f} %".format(
        epoch, np.mean(losses), 100 * (preds.detach().numpy().argmax(-1) == y).mean()))

swa_utils.update_bn(loader, swa_model)
test_accuracy = (swa_model(torch.as_tensor(X_test)).detach().numpy().argmax(-1) == y_test).mean()
```

The skorch version gets a train loss of 0.548, train accuracy of 0.733, and test accuracy of 0.752. The PyTorch version gets a train loss of 0.460, train accuracy of 0.727, and test accuracy of 0.736. So there seems to be a significant difference in train loss, but I'm not sure where it's coming from. It's not due to the described difference, as introducing the same deviation in the PyTorch code doesn't make a difference. Do you have any idea?
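Both snippets use a `ClassifierModule` that is not defined in this thread. A minimal stand-in consistent with the data above (20 input features, two classes, softmax outputs so that `NLLLoss` can be applied to the log-probabilities) might look like the following sketch; the actual architecture behind the reported numbers is unknown:

```python
import torch.nn.functional as F
from torch import nn


# Hypothetical stand-in for the ClassifierModule referenced above;
# the real module from the thread is not shown, so sizes are guesses.
class ClassifierModule(nn.Module):
    def __init__(self, n_features=20, n_hidden=40, n_classes=2):
        super().__init__()
        self.hidden = nn.Linear(n_features, n_hidden)
        self.output = nn.Linear(n_hidden, n_classes)

    def forward(self, X):
        X = F.relu(self.hidden(X))
        # return probabilities, so both versions can take their log for NLLLoss
        return F.softmax(self.output(X), dim=-1)
```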
@WillCastle Did you have an opportunity to test this out yet? I believe it might not even be necessary to store …
@BenjaminBossan Hi, sorry, I have been a little caught up in some job applications. I will have a look at this next week. I am not sure about the discrepancy in training loss; I'll run some test cases and try to work out where it's coming from. The changes you proposed look good. As to multiplying by …
@BenjaminBossan Just had another look and noticed a couple of things. In your Skorch example, you create and fit the object …
I believe that the SWA model is a PyTorch module, so I follow its creation with a conversion to a Skorch model with something like …
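One way such a conversion might look (a sketch, not the code from the original comment; it assumes skorch accepts an already instantiated module and reuses the names from the example above):

```python
# Hypothetical sketch: wrap the averaged module in a fresh skorch estimator
# so it can be used through the usual predict/score API.
swa_net = NeuralNetClassifier(net.module_swa_, train_split=False)
swa_net.initialize()
preds = swa_net.predict(X_test)
```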
Thanks for taking another look @WillCastle
Okay, this makes sense. I would probably allow both possibilities: if it's an int, take it as an absolute value; if a float, as a relative value. This is consistent with how sklearn works in some places, e.g. the …
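A sketch of how that could be handled inside the callback (a hypothetical helper, not part of the code above):

```python
# Hypothetical: accept swa_start either as an absolute epoch (int) or as a
# fraction of max_epochs (float), similar to sklearn arguments that allow both.
def _resolve_swa_start(swa_start, max_epochs):
    if isinstance(swa_start, float) and 0.0 < swa_start < 1.0:
        return int(swa_start * max_epochs)
    return int(swa_start)
```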
Yes, you're right; to be more precise, it's not the net object, but the … Unfortunately, that's not the reason for the discrepancy. I tested both the original …
… quite different validation results between PyTorch and skorch. Does skorch do some weight initialization automatically by default?
Yes, I need to investigate further, or perhaps someone else can spot a mistake.
No, this is left completely to the user. The module is initialized in exactly the same way, as is the data loader.
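One way to verify that both versions really start from the same weights (a sketch, assuming they share the same `ClassifierModule` and the names from the example above) is to initialize the skorch net and copy its module's weights into the standalone model before training:

```python
# Hypothetical check: give the plain PyTorch model exactly the same initial
# weights as the skorch net's module before either of them is trained.
net.initialize()
model = ClassifierModule()
model.load_state_dict(net.module_.state_dict())
```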
I assume you have used the code I posted above and now encountered this error. In that case, could you please replace the line:

```python
net.module_swa_ = self.swa_utils.AveragedModel(net.module_)
```

with

```python
with net._current_init_context('module'):
    net.module_swa_ = self.swa_utils.AveragedModel(net.module_)
```

and see if that fixes the issue?
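Applied in context, the `on_train_begin` hook from the callback above would then read (same code as before, only with the `AveragedModel` call wrapped):

```python
def on_train_begin(self, net, **kwargs):
    self.optimizer_swa_ = self.swa_utils.SWALR(net.optimizer_, **self.kwargs)
    if not hasattr(net, 'module_swa_'):
        # create the averaged module inside skorch's module init context
        with net._current_init_context('module'):
            net.module_swa_ = self.swa_utils.AveragedModel(net.module_)
```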
@BenjaminBossan Yes, it fixed the issue.
PyTorch recently added methods to implement Stochastic Weight Averaging (SWA):
https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/

This method can improve many models' performance by creating a new model whose weights are averaged over the last few training epochs. Paper here:
https://arxiv.org/abs/1803.05407

The PyTorch implementation requires calling methods within a training loop, but I wanted to use SWA with a Skorch network, so I wrote a callback to do it. I wondered if this would be of some use to others.
```python
train_loader, skorch_model = ...

class StochasticWeightAveraging(Callback):
    ...
```
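For reference, the loop-based pattern the PyTorch announcement describes looks roughly like this (a sketch; `model`, `optimizer`, `loss_fn`, `max_epochs`, and `swa_start` are assumed to be defined elsewhere):

```python
from torch.optim import swa_utils

# SWA helpers introduced in PyTorch 1.6
swa_model = swa_utils.AveragedModel(model)
swa_scheduler = swa_utils.SWALR(optimizer, swa_lr=0.05)

for epoch in range(max_epochs):
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        loss_fn(model(X_batch), y_batch).backward()
        optimizer.step()
    if epoch > swa_start:
        # fold the current weights into the running average and step the SWA LR
        swa_model.update_parameters(model)
        swa_scheduler.step()

# recompute batch norm statistics for the averaged model
swa_utils.update_bn(train_loader, swa_model)
```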