
Commit

Update LR schedulers only when their corresponding Optimizer is being used (#4868)

* Update LR schedulers only when their corresponding Optimizer is being used.

In the case when optimizer frequencies are specified,
the LR scheduler corresponding to a particular optimizer is updated
only when that optimizer is being used in the training loop or epoch.

* pep8speak fixes

* Fix failing tests

* Add docs

* PR Feedback

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <[email protected]>

* formatting fix

* PR Feedback - part 2

* More PR feedback

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <[email protected]>

* Add typing imports

* Stronger tests and fixes related to that

* Add more tests plus PR feedback

* Make optimizer_freq_cumsum a cached property

@cached_property is only available in Python 3.8 and later, so it had to be done manually.

* Fix tests

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <[email protected]>

* Avoid mutable defaults

* Parametrize lr scheduling tests

* PR feedback

* Apply suggestions from code review

* spell

* Apply suggestions from code review

* flake8

Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
5 people authored May 4, 2021
1 parent b780af5 commit 82c19e1
Showing 8 changed files with 288 additions and 151 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -364,6 +364,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bug where `BaseFinetuning.flatten_modules()` was duplicating leaf node parameters ([#6879](https://github.com/PyTorchLightning/pytorch-lightning/pull/6879))


- Fixed bug where the learning rate schedulers did not follow the optimizer frequencies ([#4868](https://github.com/PyTorchLightning/pytorch-lightning/pull/4868))


- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))


29 changes: 27 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -1109,8 +1109,33 @@ def configure_optimizers(self):
- **None** - Fit will run without any optimizer.
Note:
The lr_dict is a dictionary which contains the scheduler and its associated configuration. The default
configuration is shown below.
The ``frequency`` value specified in a dict along with the ``optimizer`` key is an int corresponding
to the number of sequential batches optimized with the specific optimizer.
It should be given to none or to all of the optimizers.
There is a difference between passing multiple optimizers in a list,
and passing multiple optimizers in dictionaries with a frequency of 1:
In the former case, all optimizers will operate on the given batch in each optimization step.
In the latter, only one optimizer will operate on the given batch at every step.
This is different from the ``frequency`` value specified in the lr_dict mentioned below.
.. code-block:: python
def configure_optimizers(self):
optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
return [
{'optimizer': optimizer_one, 'frequency': 5},
{'optimizer': optimizer_two, 'frequency': 10},
]
In this example, the first optimizer will be used for the first 5 steps,
the second optimizer for the next 10 steps and that cycle will continue.
If an LR scheduler is specified for an optimizer using the ``lr_scheduler`` key in the above dict,
the scheduler will only be updated when its optimizer is being used.
Note:
The lr_dict is a dictionary which contains the scheduler and its associated configuration.
The default configuration is shown below.
.. code-block:: python
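For illustration only (not part of this diff), here is a hedged sketch of the behaviour documented in the docstring above: each optimizer dict carries its own `lr_scheduler`, and with this change that scheduler is stepped only on the batches where its optimizer is active. The scheduler choice and hyperparameters are placeholders.

```python
import torch

def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {
            'optimizer': optimizer_one,
            'frequency': 5,
            # stepped only on the batches where optimizer_one is the active optimizer
            'lr_scheduler': torch.optim.lr_scheduler.StepLR(optimizer_one, step_size=1, gamma=0.9),
        },
        {
            'optimizer': optimizer_two,
            'frequency': 10,
            'lr_scheduler': torch.optim.lr_scheduler.StepLR(optimizer_two, step_size=1, gamma=0.9),
        },
    ]
```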
12 changes: 11 additions & 1 deletion pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -25,7 +27,9 @@ def on_trainer_init(self):
self.trainer.optimizers = []
self.trainer.optimizer_frequencies = []

def update_learning_rates(self, interval: str, monitor_metrics=None):
def update_learning_rates(
self, interval: str, monitor_metrics: Optional[Dict[str, Any]] = None, opt_indices: Optional[List[int]] = None
):
"""Update learning rates.
Args:
@@ -35,7 +39,13 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
return

if opt_indices is None:
opt_indices = []

for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers):
if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices:
continue

current_idx = self.trainer.batch_idx if interval == 'step' else self.trainer.current_epoch
current_idx += 1  # account for both batch and epoch counts starting from 0
# Take step if call to update_learning_rates matches the interval key and
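A minimal standalone sketch (assumed names, not the trainer's API) of the filter added above: scheduler configs tagged with an integer `opt_idx` are skipped unless that index appears in `opt_indices`.

```python
from typing import Any, Dict, List, Optional

def select_schedulers(
    lr_schedulers: List[Dict[str, Any]], opt_indices: Optional[List[int]] = None
) -> List[Dict[str, Any]]:
    """Keep only the scheduler configs whose optimizer was actually used."""
    if opt_indices is None:
        opt_indices = []
    selected = []
    for lr_scheduler in lr_schedulers:
        # A scheduler tagged with an integer opt_idx is skipped unless that optimizer ran.
        if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices:
            continue
        selected.append(lr_scheduler)
    return selected

# Only optimizer 0 was used, so only its scheduler is kept for stepping.
schedulers = [{'opt_idx': 0, 'name': 'sched_a'}, {'opt_idx': 1, 'name': 'sched_b'}]
assert select_schedulers(schedulers, opt_indices=[0]) == [{'opt_idx': 0, 'name': 'sched_a'}]
```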
20 changes: 17 additions & 3 deletions pytorch_lightning/trainer/optimizers.py
@@ -46,7 +46,10 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
if isinstance(optim_conf, Optimizer):
optimizers = [optim_conf]
# two lists, optimizer + lr schedulers
elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
elif (
isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list)
and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
):
opt, sch = optim_conf
optimizers = opt
lr_schedulers = sch if isinstance(sch, list) else [sch]
@@ -58,15 +61,25 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict]
scheduler_dict = (
lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx) if isinstance(scheduler, dict) else {
'scheduler': scheduler,
'opt_idx': opt_idx
}
)

lr_schedulers = [
scheduler_dict(opt_dict["lr_scheduler"], opt_idx) for opt_idx, opt_dict in enumerate(optim_conf)
if "lr_scheduler" in opt_dict
]
optimizer_frequencies = [
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
]
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)):
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
optimizers = list(optim_conf)
# unknown configuration
else:
@@ -207,4 +220,5 @@ def _get_default_scheduler_config() -> Dict[str, Any]:
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': None, # value to monitor for ReduceLROnPlateau
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
'opt_idx': None, # necessary to store opt_idx when optimizer frequencies are specified
}
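A hedged sketch of the normalization performed by the new `scheduler_dict` helper above: every scheduler entry is tagged with the index of its optimizer so it can later be matched against the optimizers actually used. The helper name and placeholder strings below are illustrative, not the trainer's internals.

```python
from typing import Any, Dict, Union

def tag_scheduler(scheduler: Union[Dict[str, Any], Any], opt_idx: int) -> Dict[str, Any]:
    # Dict configs keep their keys and gain 'opt_idx'; bare schedulers get wrapped.
    if isinstance(scheduler, dict):
        return dict(scheduler, opt_idx=opt_idx)
    return {'scheduler': scheduler, 'opt_idx': opt_idx}

# Placeholder strings stand in for real optimizer/scheduler objects.
optim_conf = [
    {'optimizer': 'opt_a', 'frequency': 5, 'lr_scheduler': 'sched_a'},
    {'optimizer': 'opt_b', 'frequency': 10},  # no scheduler attached
]
lr_schedulers = [
    tag_scheduler(opt_dict['lr_scheduler'], opt_idx)
    for opt_idx, opt_dict in enumerate(optim_conf)
    if 'lr_scheduler' in opt_dict
]
assert lr_schedulers == [{'scheduler': 'sched_a', 'opt_idx': 0}]
```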
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -985,7 +985,15 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:

# update epoch-level lr_schedulers
if on_epoch:
self.optimizer_connector.update_learning_rates(interval='epoch')
self.optimizer_connector.update_learning_rates(
interval='epoch',
opt_indices=[
opt_idx
for opt_idx, _ in self.train_loop.get_optimizers_iterable(batch_idx=(
self.total_batch_idx - 1
)) # Select the optimizers which were used in the last batch of the epoch
],
)

# hook
self.evaluation_loop.on_evaluation_end()
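A brief worked sketch of the epoch-end selection above, assuming two optimizers with frequencies 5 and 10: the last batch of the epoch determines which optimizer's epoch-level schedulers may step. Names here are illustrative.

```python
import numpy as np

# Assumed setup: two optimizers with frequencies 5 and 10, i.e. a 15-batch cycle.
freq_cumsum = np.cumsum([5, 10])

def last_used_opt_idx(total_batch_idx: int) -> int:
    # Mirrors get_optimizers_iterable(batch_idx=total_batch_idx - 1) at epoch end.
    place = (total_batch_idx - 1) % freq_cumsum[-1]
    return int(np.argmax(freq_cumsum > place))

# An epoch ending after 5 batches last used optimizer 0; after 15 batches, optimizer 1.
assert last_used_opt_idx(5) == 0
assert last_used_opt_idx(15) == 1
```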
25 changes: 19 additions & 6 deletions pytorch_lightning/trainer/training_loop.py
@@ -47,6 +47,7 @@ def __init__(self, trainer, multiple_trainloader_mode: str):
self._multiple_trainloader_mode = multiple_trainloader_mode
self._skip_backward = False
self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
self._optimizer_freq_cumsum = None

def on_trainer_init(
self,
@@ -83,6 +84,12 @@ def num_optimizers(self):
num_optimizers = len(self.get_optimizers_iterable())
return num_optimizers

@property
def optimizer_freq_cumsum(self):
if self._optimizer_freq_cumsum is None:
self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
return self._optimizer_freq_cumsum

def should_skip_training(self):
should_by_max_steps = self.trainer.max_steps is not None and self.trainer.global_step >= self.trainer.max_steps
should_by_epoch = self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs
@@ -211,20 +218,22 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):

epoch_output[opt_idx].append(opt_outputs)

def get_optimizers_iterable(self):
def get_optimizers_iterable(self, batch_idx=None):
"""
Generates an iterable with (idx, optimizer) for each optimizer.
"""
if not self.trainer.optimizer_frequencies:
# call training_step once per optimizer
return list(enumerate(self.trainer.optimizers))

optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
optimizers_loop_length = optimizer_freq_cumsum[-1]
current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length
if batch_idx is None:
batch_idx = self.trainer.total_batch_idx

optimizers_loop_length = self.optimizer_freq_cumsum[-1]
current_place_in_loop = batch_idx % optimizers_loop_length

# find optimizer index by looking for the first {item > current_place} in the cumsum list
opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
opt_idx = np.argmax(self.optimizer_freq_cumsum > current_place_in_loop)
return [[opt_idx, self.trainer.optimizers[opt_idx]]]

def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
@@ -801,7 +810,11 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None):

if num_accumulated_batches_reached or num_training_batches_reached:
# update lr
self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)
self.trainer.optimizer_connector.update_learning_rates(
interval="step",
monitor_metrics=monitor_metrics,
opt_indices=[opt_idx for opt_idx, _ in self.get_optimizers_iterable()],
)

def increment_accumulated_grad_global_step(self):
num_accumulated_batches_reached = self._accumulated_batches_reached()
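For intuition (not part of the diff), a small worked sketch of how the cached cumulative-sum lookup in `get_optimizers_iterable` maps a batch index to an optimizer index, assuming frequencies of 5 and 10:

```python
import numpy as np

optimizer_frequencies = [5, 10]
optimizer_freq_cumsum = np.cumsum(optimizer_frequencies)  # array([ 5, 15])
optimizers_loop_length = optimizer_freq_cumsum[-1]        # 15 batches per cycle

def opt_idx_for_batch(batch_idx: int) -> int:
    current_place_in_loop = batch_idx % optimizers_loop_length
    # The first cumsum entry greater than the current place selects the optimizer.
    return int(np.argmax(optimizer_freq_cumsum > current_place_in_loop))

# Batches 0-4 use optimizer 0, batches 5-14 use optimizer 1, then the cycle repeats.
assert [opt_idx_for_batch(b) for b in (0, 4, 5, 14, 15)] == [0, 0, 1, 1, 0]
```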
11 changes: 0 additions & 11 deletions tests/base/model_optimizers.py
@@ -26,9 +26,6 @@ def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer

def configure_optimizers__empty(self):
return None

def configure_optimizers__lbfgs(self):
"""
return whatever optimizers we want here.
@@ -41,14 +38,6 @@ def configure_optimizers__adagrad(self):
optimizer = optim.Adagrad(self.parameters(), lr=self.learning_rate)
return optimizer

def configure_optimizers__multiple_optimizers_frequency(self):
optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate)
return [
dict(optimizer=optimizer1, frequency=1),
dict(optimizer=optimizer2, frequency=5),
]

def configure_optimizers__single_scheduler(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)