Learning rate finder fails for optimizer with internal state #3340

Closed
tsteffek opened this issue Sep 3, 2020 · 3 comments · Fixed by #3897
Labels: bug (Something isn't working) · help wanted (Open to be worked on) · waiting on author (Waiting on user action, correction, or update)

@tsteffek commented Sep 3, 2020

🐛 Bug

DISCLAIMER: I originally thought this was solved by the new Trainer.tune() approach and was only writing this to raise awareness and for anyone searching for it. Apparently it is not solved (see the comments below).

When using the auto_lr_find flag, I noticed that the learning rate finder fails when used with torch.optim.Adagrad. Upon closer inspection, I realized that this is due to the internal state of the optimizer.
When the optimizer is first initialized, it populates its state attribute with tensors shaped like the model parameters (in Trainer.fit(), line 1011, v0.9.0). However, the model parameters are only moved to the correct device afterwards (in Trainer.fit(), line 1030+ depending on backend, v0.9.0).
This results in an error, since the optimizer state is still on the CPU while the model has already been moved to CUDA.
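
The same mismatch can be reproduced with a minimal sketch outside Lightning (assuming a CUDA device is available): Adagrad allocates its state['sum'] buffers eagerly in its constructor, so moving the model afterwards leaves those buffers behind on the CPU.

import torch

model = torch.nn.Linear(10, 5)                       # parameters start on the CPU
optimizer = torch.optim.Adagrad(model.parameters())  # Adagrad creates state['sum'] on the CPU right away

model.cuda()  # parameters move to the GPU, but optimizer.state still holds CPU tensors

out = model(torch.randn(2, 10, device="cuda"))
out.sum().backward()
optimizer.step()  # raises the RuntimeError shown below (cuda:0 vs cpu)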

Error message:
Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]Traceback (most recent call last):
  File "C:/Users/anarc/AppData/Roaming/JetBrains/PyCharm2020.2/scratches/torch/scratch.py", line 41, in <module>
    trainer.fit(model)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1011, in fit
    self._run_lr_finder_internally(model)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\lr_finder.py", line 72, in _run_lr_finder_internally
    lr_finder = self.lr_find(model)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\lr_finder.py", line 200, in lr_find
    self.fit(model,
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1073, in fit
    results = self.accelerator_backend.train(model)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\accelerators\gpu_backend.py", line 51, in train
    results = self.trainer.run_pretrain_routine(model)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1239, in run_pretrain_routine
    self.train()
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 394, in train
    self.run_training_epoch()
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 491, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 890, in run_training_batch
    grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 951, in run_batch_backward_pass
    self.call_optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 988, in call_optimizer_step
    model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure,
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\pytorch_lightning\core\lightning.py", line 1160, in optimizer_step
    optimizer.step()
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\torch\optim\lr_scheduler.py", line 67, in wrapper
    return wrapped(*args, **kwargs)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\torch\autograd\grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "C:\tools\miniconda3\envs\ElmoOnTorch\lib\site-packages\torch\optim\adagrad.py", line 98, in step
    state['sum'].addcmul_(grad, grad, value=1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]

Process finished with exit code 1

Code sample

For this minimal example, I took the model from the docs and fitted it on a random dataset.

import torch
from pytorch_lightning import LightningModule, Trainer
from torch.nn import functional as F
from torch.utils.data import DataLoader, IterableDataset


class RandomDataset(IterableDataset):
    def __iter__(self):
        while True:
            yield torch.randn(10), 1


class LitModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(10, 5)
        self.lr = 0.001

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}

        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adagrad(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        return DataLoader(RandomDataset())


if __name__ == '__main__':
    model = LitModel()
    trainer = Trainer(gpus=1, auto_lr_find=True)
    trainer.fit(model)

Expected behavior

When using the learning rate finder, the Trainer should first correctly set up the model (i.e. move it to the target device) before initializing the optimizers.
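
Until that ordering is fixed, one possible workaround (just a sketch, not part of the Lightning API) is to move any tensors the optimizer has already created in its internal state onto the parameters' device before the first step, e.g. from a hook that runs after the model has been moved to its device:

import torch

def move_optimizer_state(optimizer, device):
    # Relocate eagerly created state tensors (e.g. Adagrad's state['sum'])
    # to the device the model parameters now live on.
    for param_state in optimizer.state.values():
        for key, value in param_state.items():
            if torch.is_tensor(value):
                param_state[key] = value.to(device)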

Environment

  • CUDA:
    - GPU:
    - GeForce GTX 1050
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.19.1
    - pyTorch_debug: False
    - pyTorch_version: 1.6.0
    - pytorch-lightning: 0.9.0
    - tensorboard: 2.2.0
    - tqdm: 4.48.2
  • System:
    - OS: Windows
    - architecture:
    - 64bit
    - WindowsPE
    - processor: Intel64 Family 6 Model 158 Stepping 9, GenuineIntel
    - python: 3.8.5
    - version: 10.0.18362
tsteffek added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Sep 3, 2020
@awaelchli (Contributor) commented:

> DISCLAIMER: I think this is solved in the new Trainer.tune() approach.

I just checked; the problem remains when we call .tune().

@Borda (Member) commented Oct 5, 2020

> DISCLAIMER: I think this is solved in the new Trainer.tune() approach.

> I just checked; the problem remains when we call .tune().

@awaelchli mind sharing what the failing call was?
I have tried the following and everything was fine...

if __name__ == '__main__':
    model = LitModel()
    trainer = Trainer(gpus=0, max_steps=50)
    lrfinder = trainer.tuner.lr_find(model, mode='linear')
    model.learning_rate = lrfinder.suggestion()
    trainer.tune(model)
    trainer.fit(model)

Borda added the waiting on author (Waiting on user action, correction, or update) label on Oct 5, 2020
@awaelchli (Contributor) commented:

import torch
from pytorch_lightning import LightningModule, Trainer
from torch.nn import functional as F
from torch.utils.data import DataLoader, IterableDataset


class RandomDataset(IterableDataset):
    def __iter__(self):
        while True:
            yield torch.randn(10), 1


class LitModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(10, 5)
        self.lr = 0.001

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}

        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adagrad(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        return DataLoader(RandomDataset())


if __name__ == '__main__':
    model = LitModel()
    trainer = Trainer(gpus=1, auto_lr_find=True)
    trainer.tune(model)  # <---- added this; it is now required in the new API
    trainer.fit(model)

This is the same code as originally posted, but with the additional trainer.tune(model) call.
The error message is the same as in the description above.
