From 58bb4c94dd38d414827d2351bdb7297159a9ae7d Mon Sep 17 00:00:00 2001
From: Rustam Zhumagambetov
Date: Mon, 28 Apr 2025 15:52:04 +0200
Subject: [PATCH 1/5] add toggled_optimizer to LightningModule

---
 src/lightning/pytorch/core/module.py              | 24 +++++++++++++++++++
 .../core/test_lightning_module.py                 | 15 ++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index b8624daac3fa3..9f06d9b14de64 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -1141,6 +1141,30 @@ def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) ->
         # save memory
         self._param_requires_grad_state = {}
 
+    @contextmanager
+    def toggled_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> Generator:
+        """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to
+        prevent dangling gradients in a multiple-optimizer setup. Combines :meth:`toggle_optimizer` and
+        :meth:`untoggle_optimizer` into a context manager.
+
+        Args:
+            optimizer: The optimizer to untoggle.
+
+        Example::
+
+            def training_step(...):
+                opt = self.optimizers()
+                with self.toggled_optimizer(opt):
+                    loss = ...
+                    opt.zero_grad()
+                    self.manual_backward(loss)
+                    opt.step()
+        """
+        try:
+            yield self.toggle_optimizer(optimizer)
+        finally:
+            self.untoggle_optimizer(optimizer)
+
     def clip_gradients(
         self,
         optimizer: Optimizer,
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index 2036014762ebf..513d2f2cfc683 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -118,6 +118,21 @@ def test_1_optimizer_toggle_model():
     model.untoggle_optimizer(optimizer)
     assert not model._param_requires_grad_state
 
+def test_1_optimizer_toggle_model_context_manager():
+    """Test toggle_model runs when only one optimizer is used."""
+    model = BoringModel()
+    trainer = Mock()
+    model.trainer = trainer
+    params = model.parameters()
+    optimizer = torch.optim.SGD(params, lr=0.1)
+    trainer.optimizers = [optimizer]
+
+    assert not model._param_requires_grad_state
+    # toggle_optimizer was failing with a single optimizer
+    with model.toggled_optimizer(optimizer):
+        assert model._param_requires_grad_state
+    assert not model._param_requires_grad_state
+
 
 def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmp_path):
     class TestModel(BoringModel):

From b31134ae1b965d417ad014ab283cb4a0fe4d3bda Mon Sep 17 00:00:00 2001
From: Rustam Zhumagambetov
Date: Mon, 28 Apr 2025 15:52:30 +0200
Subject: [PATCH 2/5] add docs for toggled_optimizer to LightningModule

---
 docs/source-pytorch/conf.py                       | 1 +
 docs/source-pytorch/model/manual_optimization.rst | 2 +-
 src/lightning/pytorch/CHANGELOG.md                | 2 ++
 src/lightning/pytorch/core/module.py              | 2 +-
 4 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py
index 90400b1df491d..62cd21fc127f4 100644
--- a/docs/source-pytorch/conf.py
+++ b/docs/source-pytorch/conf.py
@@ -487,6 +487,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
     ("py:meth", "setup"),
     ("py:meth", "test_step"),
     ("py:meth", "toggle_optimizer"),
+    ("py:meth", "toggled_optimizer"),
     ("py:class", "torch.ScriptModule"),
     ("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload"),
     ("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision"),
diff --git a/docs/source-pytorch/model/manual_optimization.rst b/docs/source-pytorch/model/manual_optimization.rst
index 150f04793eae6..4c7400c0457ca 100644
--- a/docs/source-pytorch/model/manual_optimization.rst
+++ b/docs/source-pytorch/model/manual_optimization.rst
@@ -17,7 +17,7 @@ To manually optimize, do the following:
 * ``optimizer.zero_grad()`` to clear the gradients from the previous training step
 * ``self.manual_backward(loss)`` instead of ``loss.backward()``
 * ``optimizer.step()`` to update your model parameters
-* ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()`` if needed
+* ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()``, or ``self.toggled_optimizer()`` if needed
 
 Here is a minimal example of manual optimization.
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 5616defeffc8a..514d195db8e90 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
 
+- Add `toggled_optimizer(optimizer)` method to the LightningModule, which is a context manager version of `toggle_optimizer` and `untoggle_optimizer`
+
 ### Changed
 
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 9f06d9b14de64..188b9e548cd5f 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -1148,7 +1148,7 @@ def toggled_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) ->
         :meth:`untoggle_optimizer` into a context manager.
 
         Args:
-            optimizer: The optimizer to untoggle.
+            optimizer: The optimizer to toggle.
 
         Example::

From 018a50efa0eb015e4c66428f4bde6b0b0caf1592 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 28 Apr 2025 13:53:52 +0000
Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/CHANGELOG.md                | 2 +-
 src/lightning/pytorch/core/module.py              | 3 ++-
 tests/tests_pytorch/core/test_lightning_module.py | 1 +
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 514d195db8e90..fe77a15ae360e 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -11,7 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
 
-- Add `toggled_optimizer(optimizer)` method to the LightningModule, which is a context manager version of `toggle_optimizer` and `untoggle_optimizer` 
+- Add `toggled_optimizer(optimizer)` method to the LightningModule, which is a context manager version of `toggle_optimizer` and `untoggle_optimizer`
 
 ### Changed
 
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 188b9e548cd5f..4ec1679b9eb7a 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -1144,7 +1144,7 @@ def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) ->
     @contextmanager
     def toggled_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> Generator:
         """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to
-        prevent dangling gradients in a multiple-optimizer setup. Combines :meth:`toggle_optimizer` and 
+        prevent dangling gradients in a multiple-optimizer setup. Combines :meth:`toggle_optimizer` and
         :meth:`untoggle_optimizer` into a context manager.
 
         Args:
@@ -1159,6 +1159,7 @@ def training_step(...):
                 opt.zero_grad()
                 self.manual_backward(loss)
                 opt.step()
+
         """
         try:
             yield self.toggle_optimizer(optimizer)
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index 513d2f2cfc683..25458692c175f 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -118,6 +118,7 @@ def test_1_optimizer_toggle_model():
     model.untoggle_optimizer(optimizer)
     assert not model._param_requires_grad_state
 
+
 def test_1_optimizer_toggle_model_context_manager():
     """Test toggle_model runs when only one optimizer is used."""
     model = BoringModel()

From ca2e66cbe17430ba4fa3bd9817b848b43afb1856 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Date: Tue, 10 Jun 2025 13:16:31 +0200
Subject: [PATCH 4/5] Apply suggestions from code review

---
 src/lightning/pytorch/core/module.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 4ec1679b9eb7a..2cd93b695ebbb 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -1161,8 +1161,9 @@ def training_step(...):
                 opt.step()
 
         """
+        self.toggle_optimizer(optimizer)
         try:
-            yield self.toggle_optimizer(optimizer)
+            yield
         finally:
             self.untoggle_optimizer(optimizer)
 

From a732ad864861cdaf46252e05e21dfdd13fe0e145 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Date: Tue, 10 Jun 2025 13:17:43 +0200
Subject: [PATCH 5/5] Apply suggestions from code review

---
 tests/tests_pytorch/core/test_lightning_module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index 25458692c175f..c33488a4f2626 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -119,7 +119,7 @@ def test_1_optimizer_toggle_model():
     assert not model._param_requires_grad_state
 
 
-def test_1_optimizer_toggle_model_context_manager():
+def test_optimizer_toggle_model_context_manager():
     """Test toggle_model runs when only one optimizer is used."""
     model = BoringModel()
     trainer = Mock()
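
A minimal sketch of how `toggled_optimizer` is meant to be used once the series is applied, in the multiple-optimizer scenario the docstring targets. The `GANLikeModel` below, its layers, and its losses are illustrative placeholders rather than part of the patch; only `toggled_optimizer`, `optimizers()`, `manual_backward`, and `automatic_optimization` come from the LightningModule API::

    import torch
    from torch import nn
    from lightning.pytorch import LightningModule

    class GANLikeModel(LightningModule):
        """Illustrative two-optimizer module; the networks and losses are placeholders."""

        def __init__(self):
            super().__init__()
            self.automatic_optimization = False  # toggling only matters with manual optimization
            self.generator = nn.Linear(32, 32)
            self.discriminator = nn.Linear(32, 1)

        def configure_optimizers(self):
            opt_g = torch.optim.SGD(self.generator.parameters(), lr=0.1)
            opt_d = torch.optim.SGD(self.discriminator.parameters(), lr=0.1)
            return opt_g, opt_d

        def training_step(self, batch, batch_idx):
            opt_g, opt_d = self.optimizers()

            # Only the generator's parameters can accumulate gradients in this block;
            # untoggling runs on exit even if the forward or backward pass raises.
            with self.toggled_optimizer(opt_g):
                g_loss = self.discriminator(self.generator(batch)).mean()
                opt_g.zero_grad()
                self.manual_backward(g_loss)
                opt_g.step()

            # The same pattern for the discriminator, with no manual untoggle to forget.
            with self.toggled_optimizer(opt_d):
                d_loss = -self.discriminator(batch).mean()
                opt_d.zero_grad()
                self.manual_backward(d_loss)
                opt_d.step()

Compared with pairing `toggle_optimizer` and `untoggle_optimizer` by hand, the `try`/`finally` inside `toggled_optimizer` guarantees that untoggling runs even when the block raises, so no parameters are left frozen for the next step.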
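
The new test asserts on `_param_requires_grad_state`, which hints at the mechanism behind toggling: parameters the toggled optimizer does not own get `requires_grad = False`, and the saved flags are restored on untoggle. A self-contained sketch of that idea, simplified and not identical to Lightning's actual bookkeeping (the helper name `toggled` and the tiny network are made up for illustration)::

    from contextlib import contextmanager

    import torch
    from torch import nn

    @contextmanager
    def toggled(optimizer, module):
        """Freeze every parameter the optimizer does not own; restore the old flags on exit."""
        owned = {p for group in optimizer.param_groups for p in group["params"]}
        saved = {p: p.requires_grad for p in module.parameters()}
        for p in module.parameters():
            if p not in owned:
                p.requires_grad = False
        try:
            yield
        finally:
            for p, flag in saved.items():
                p.requires_grad = flag

    net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
    opt = torch.optim.SGD(net[0].parameters(), lr=0.1)

    with toggled(opt, net):
        assert all(not p.requires_grad for p in net[1].parameters())  # frozen
        assert all(p.requires_grad for p in net[0].parameters())  # still trainable
    assert all(p.requires_grad for p in net.parameters())  # restored on exit

This is also why the tests can expect `model._param_requires_grad_state` to be populated inside the `with` block and empty again afterwards.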