lr scheduler doesn't follow optimizer's frequency [in an edge case] #12169

Closed · akihironitta opened this issue Mar 1, 2022 · 2 comments · Fixed by #16539
Labels
bug, lr scheduler, priority: 2

Comments

akihironitta (Contributor) commented Mar 1, 2022

🐛 Bug

This is an unlikely case, but UserWarning: Detected call of lr_scheduler.step() before optimizer.step() is raised under the following conditions (a standalone sketch of what this means in practice follows the list):

  • two optimizers and at least one lr scheduler are used
  • AND the first optimizer's frequency is greater than or equal to the number of batches in an epoch (so the second optimizer's step is never called)
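
For intuition, here is a minimal standalone sketch (plain Python, not Lightning internals; the cycling rule is inferred from the observed output below) of how frequency-based optimizer cycling plays out with the repro values:

# Standalone illustration only; nothing below is Lightning API.
FREQUENCIES = [5, 3]      # frequencies of optimizer0 and optimizer1
LIMIT_TRAIN_BATCHES = 4   # batches per epoch in the repro script

def active_optimizer(batch_idx: int) -> int:
    """Return the index of the optimizer assumed to step on this batch."""
    pos = batch_idx % sum(FREQUENCIES)
    for opt_idx, freq in enumerate(FREQUENCIES):
        if pos < freq:
            return opt_idx
        pos -= freq

for epoch in range(2):
    for batch_idx in range(LIMIT_TRAIN_BATCHES):
        print(f"epoch{epoch} batch{batch_idx} -> optimizer{active_optimizer(batch_idx)}")
# Every line prints optimizer0: with only 4 batches per epoch and a frequency
# of 5, optimizer1 never gets a turn, so scheduler1 should never step either.

Because the epoch ends before the first optimizer's frequency budget is used up, the cycle never reaches optimizer1, yet the epoch-end scheduler step still advances to scheduler1, as shown in the observed output below.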

To Reproduce

Run the script below with any values that satisfy LIMIT_TRAIN_BATCHES <= FREQUENCY1 to reproduce the warning.

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

# raises the warning if LIMIT_TRAIN_BATCHES<=FREQUENCY1
LIMIT_TRAIN_BATCHES = 4
FREQUENCY1 = 5
FREQUENCY2 = 3

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx, optimizer_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        optimizer1 = torch.optim.Adam(self.parameters(), lr=0.01)
        optimizer2 = torch.optim.Adam(self.parameters(), lr=0.01)
        scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=5)
        scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer2, T_max=2
        )
        lr_scheduler_config_1 = {"scheduler": scheduler1, "interval": "epoch"}
        lr_scheduler_config_2 = {"scheduler": scheduler2, "interval": "epoch"}
        return [
            {
                "optimizer": optimizer1,
                "frequency": FREQUENCY1,
                "lr_scheduler": lr_scheduler_config_1,
            },
            {
                "optimizer": optimizer2,
                "frequency": FREQUENCY2,
                "lr_scheduler": lr_scheduler_config_2,
            },
        ]

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure=None,
        on_tpu: bool = False,
        using_native_amp: bool = False,
        using_lbfgs: bool = False,
    ) -> None:
        print(f"batch{batch_idx} optimizer{optimizer_idx}.step()")
        return super().optimizer_step(
            epoch,
            batch_idx,
            optimizer,
            optimizer_idx,
            optimizer_closure,
            on_tpu,
            using_native_amp,
            using_lbfgs,
        )

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        print(f"       scheduler{optimizer_idx}.step()")
        return super().lr_scheduler_step(scheduler, optimizer_idx, metric)

def main():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        limit_train_batches=LIMIT_TRAIN_BATCHES,
        max_epochs=2,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
    )
    trainer.fit(model, train_dataloaders=train_data)

if __name__ == "__main__":
    main()
Observed output:

batch0 optimizer0.step()
batch1 optimizer0.step()
batch2 optimizer0.step()
batch3 optimizer0.step()
       scheduler0.step()
batch0 optimizer0.step()
batch1 optimizer0.step()
batch2 optimizer0.step()
batch3 optimizer0.step()
       scheduler1.step()  # <- should be `scheduler0.step()`

Expected behavior

scheduler1.step() should never be called in the above script; the expected output is:

batch0 optimizer0.step()
batch1 optimizer0.step()
batch2 optimizer0.step()
batch3 optimizer0.step()
       scheduler0.step()
batch0 optimizer0.step()
batch1 optimizer0.step()
batch2 optimizer0.step()
batch3 optimizer0.step()
       scheduler0.step()
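
For illustration only, a hedged sketch of the expected pairing: the epoch-end scheduler step should follow whichever optimizer actually stepped during the epoch, rather than advancing through schedulers on an independent counter. The names below (schedulers_to_step, stepped_optimizers, scheduler_configs) are hypothetical and are not Lightning API:

# Illustrative-only sketch of the expected scheduler/optimizer pairing.
def schedulers_to_step(stepped_optimizers, scheduler_configs):
    """Yield (optimizer_idx, scheduler) pairs whose optimizer actually stepped."""
    for opt_idx, scheduler in scheduler_configs:
        if opt_idx in stepped_optimizers:
            yield opt_idx, scheduler

# In the repro, only optimizer0 ever steps during an epoch:
configs = [(0, "scheduler0"), (1, "scheduler1")]
for opt_idx, sched in schedulers_to_step({0}, configs):
    print(f"{sched}.step()  # follows optimizer{opt_idx}")
# Prints only "scheduler0.step()", matching the expected output above.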

Environment

  • PyTorch Lightning Version (e.g., 1.5.0): master
  • PyTorch Version (e.g., 1.10):
  • Python version (e.g., 3.9):
  • OS (e.g., Linux):
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • How you installed PyTorch (conda, pip, source):
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information:

Additional context

Found in #11755, where the following test case throws the warning:
https://github.com/PyTorchLightning/pytorch-lightning/blob/9e7bd9c72d0150940138e7ae8f272afcfe56cd01/tests/trainer/optimization/test_optimizers.py#L273-L282

If you think this is a critical issue, please assign this issue to me, and I'll prioritize this :)

cc @tchaton @rohitgr7

akihironitta added the bug and lr scheduler labels on Mar 1, 2022
rohitgr7 (Contributor) commented Mar 2, 2022

stale bot commented Apr 16, 2022

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix label on Apr 16, 2022
stale bot closed this as completed on Apr 24, 2022
akihironitta added the priority: 2 label and removed the won't fix label on Jul 5, 2022
akihironitta reopened this on Jul 5, 2022