
OneCycleLR scheduler does not work with freeze-unfreeze finetuning strategy #1321

Closed
marrrcin opened this issue Apr 29, 2022 · 0 comments · Fixed by #1329

@marrrcin

🐛 Bug

I wanted to create an image classifier by fine-tuning a pre-trained model on my dataset. When the OneCycleLR scheduler is used together with the freeze-unfreeze finetuning strategy, training throws an exception once the unfreeze epoch is reached.

To Reproduce / Code Sample

I use Flash's built-in ImageClassifier as follows:

    # `datamodule` and `steps_per_epoch` are defined earlier from my dataset (omitted here).
    import flash
    import torch
    from flash.image import ImageClassifier
    from torchmetrics import Accuracy

    epochs = 50
    model = ImageClassifier(
        backbone="efficientnet_b5",
        labels=datamodule.labels,
        metrics=[
            Accuracy(),
        ],
        optimizer="AdamW",
        lr_scheduler=(
            "onecyclelr",
            {
                "max_lr": 1e-3,
                "epochs": epochs,
                "steps_per_epoch": steps_per_epoch,
            },
            {"interval": "step"},
        ),
    )
    trainer = flash.Trainer(
        max_epochs=epochs,
        gpus=torch.cuda.device_count(),
    )
    trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 5))

Expected behaviour

After the specified number of epochs, the layers are unfrozen and training continues.

Actual behaviour

An exception is thrown:

  File "/<redacted>/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/<redacted>/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
    return self._run_train()
  File "/<redacted>/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1319, in _run_train
    self.fit_loop.run()
  File "/<redacted>/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/<redacted>/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/<redacted>/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/<redacted>/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 199, in advance
    self.update_lr_schedulers("step", update_plateau_schedulers=False)
  File "/<redacted>/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 441, in update_lr_schedulers
    self._update_learning_rates(
  File "/<redacted>/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 505, in _update_learning_rates
    lr_scheduler["scheduler"].step()
  File "/<redacted>/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 154, in step
    values = self.get_lr()
  File "/<redacted>/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 1597, in get_lr
    computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
KeyError: 'max_lr'

It seems like the unfreezing strategy creates additional optimizer parameter groups, but when the unfreezing happens, some of the LR scheduler's per-parameter-group settings are not copied / passed to the new param group in pytorch_lightning.callbacks.finetuning.BaseFinetuning.unfreeze_and_add_param_group.
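
For illustration, here is a minimal sketch in plain PyTorch (no Flash or Lightning involved; backbone and head are just placeholder modules) of what I suspect is happening: OneCycleLR writes per-parameter-group keys such as initial_lr, max_lr and min_lr into each param group at construction time, so a group added later via optimizer.add_param_group never receives them, and the next scheduler.step() fails with a KeyError:

    import torch
    from torch import nn
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import OneCycleLR

    backbone = nn.Linear(8, 8)  # stands in for the initially frozen backbone
    head = nn.Linear(8, 2)      # stands in for the trainable head

    # The scheduler is created while only the head's param group exists.
    optimizer = AdamW(head.parameters(), lr=1e-3)
    scheduler = OneCycleLR(optimizer, max_lr=1e-3, epochs=50, steps_per_epoch=100)
    print(sorted(optimizer.param_groups[0]))  # includes 'initial_lr', 'max_lr', 'min_lr'

    # Simulate unfreeze_and_add_param_group: a new group is added after the
    # scheduler was constructed, without those per-group schedule keys.
    optimizer.add_param_group({"params": backbone.parameters(), "lr": 1e-4})
    print(sorted(optimizer.param_groups[1]))  # the schedule keys are missing here

    scheduler.step()  # raises a KeyError for one of the missing keys, as in the traceback above

If that is indeed the cause, the fix (or a workaround) would presumably be to copy those per-group keys into the newly added param group, or to re-create the scheduler, when the unfreeze happens.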

Environment

  • OS (e.g., Linux): macOS
  • Python version: 3.8.12
  • PyTorch/Lightning/Flash Version (e.g., 1.10/1.5/0.7): 1.11.0 / 1.5.10 / 0.7.3
  • GPU models and configuration: 0 / 1× T4 (happens regardless of CUDA)
  • Any other relevant information:

Additional context

https://pytorch-lightning.slack.com/archives/CRBLFHY79/p1651218144224359

@marrrcin added the bug / fix and help wanted labels on Apr 29, 2022
@ethanwharris added this to the 0.7.x milestone on May 4, 2022
@ethanwharris self-assigned this on May 5, 2022