From 14a9a93b12dbd54135663b30257878930aab16a0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 2 Sep 2021 18:05:26 +0200 Subject: [PATCH 01/13] Fix plugin closure execution order --- pytorch_lightning/plugins/precision/apex_amp.py | 2 +- .../plugins/precision/deepspeed_precision.py | 2 +- pytorch_lightning/plugins/precision/native_amp.py | 1 + tests/models/test_hooks.py | 9 ++++----- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 297571d1174c8..7f335fc0b58ce 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -97,9 +97,9 @@ def pre_optimizer_step( **kwargs: Any, ) -> bool: """Hook to do something before each optimizer step.""" + lambda_closure() # APEX amp does not support closures super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs) # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin. - lambda_closure() # APEX amp does not support closures optimizer.step(**kwargs) return False diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index c127a2076455f..1e61f07a61591 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -42,9 +42,9 @@ def pre_optimizer_step( **kwargs: Any, ) -> bool: """Hook to do something before each optimizer step.""" + lambda_closure() # DeepSpeed does not support closures super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs) # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin. - lambda_closure() # DeepSpeed does not support closures deepspeed_engine = model.trainer.model deepspeed_engine.step() return False diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 9373625f66d02..dc2a2a5cbab33 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -96,6 +96,7 @@ def pre_optimizer_step( " To request, please file a Github issue in PyTorch and tag @mcarilli" ) result = True + # FIXME: is this correct for manual? if model.automatic_optimization: result = lambda_closure() self.scaler.unscale_(optimizer) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 4a787a833dd1c..338fdb4003a6f 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -275,6 +275,7 @@ def _train_batch(self, *args, **kwargs): def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs): using_native_amp = kwargs.get("amp_backend") == "native" using_deepspeed = kwargs.get("plugins") == "deepspeed" + using_plugin = kwargs.get("amp_backend") or kwargs.get("plugins") out = [] on_before_optimizer_step = [ dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)), @@ -290,10 +291,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre dict(name="Callback.on_batch_start", args=(trainer, model)), dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i, 0)), dict(name="on_train_batch_start", args=(ANY, i, 0)), - # these are before the training step because - # they are not part of the `training_step_and_backward` closure, however, - # with native amp, the closure is run first and then the optimizer step. - *(on_before_optimizer_step if not using_native_amp else []), + # without a precision plugin, we execute the closure inside the `optimizer.step` + *([] if using_plugin else on_before_optimizer_step), dict(name="forward", args=(ANY,)), dict(name="training_step", args=(ANY, i)), dict(name="training_step_end", args=(dict(loss=ANY),)), @@ -306,7 +305,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []), dict(name="Callback.on_after_backward", args=(trainer, model)), dict(name="on_after_backward"), - *(on_before_optimizer_step if using_native_amp else []), + *(on_before_optimizer_step if using_plugin else []), dict( name="optimizer_step", args=(current_epoch, i, ANY, 0, ANY), From 7fe78fdff2593f6a86df78db4518955fb4eaf4fb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 2 Sep 2021 18:10:13 +0200 Subject: [PATCH 02/13] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a10d80b91678..d858282d3533a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -285,6 +285,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333)) +- Fixed the Apex and DeepSpeed plugin closure running after the `on_before_optimizer_step` hook ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288)) + + - Fixed bug where data-loading functions where not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858)) From 73b03d42098a6a200740560ac165f89c57f3245b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 3 Sep 2021 03:45:49 +0200 Subject: [PATCH 03/13] Fix manual optimization on AMP and skipping backward --- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/plugins/precision/apex_amp.py | 9 ++++++--- .../plugins/precision/deepspeed_precision.py | 11 +++++++---- pytorch_lightning/plugins/precision/native_amp.py | 11 +++++------ 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e1b4d1f3f7e58..1004db58bd609 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -633,7 +633,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: - ``None`` - Training will skip to the next batch Note: - Returning ``None`` is currently not supported for multi-GPU or TPU, or with 16-bit precision enabled. + Returning ``None`` is currently not supported for multi-GPU or TPU. In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific. diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 7f335fc0b58ce..27fa856d0f4b9 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -97,10 +97,13 @@ def pre_optimizer_step( **kwargs: Any, ) -> bool: """Hook to do something before each optimizer step.""" - lambda_closure() # APEX amp does not support closures + result = lambda_closure() # APEX amp does not support closures super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs) - # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin. - optimizer.step(**kwargs) + skipped_backward = result is None + # in manual optimization, the closure does not return a value + if not model.automatic_optimization or not skipped_backward: + # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin. + optimizer.step(**kwargs) return False def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 1e61f07a61591..b931dbafb7058 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -42,11 +42,14 @@ def pre_optimizer_step( **kwargs: Any, ) -> bool: """Hook to do something before each optimizer step.""" - lambda_closure() # DeepSpeed does not support closures + result = lambda_closure() # DeepSpeed does not support closures super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs) - # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin. - deepspeed_engine = model.trainer.model - deepspeed_engine.step() + skipped_backward = result is None + # in manual optimization, the closure does not return a value + if not model.automatic_optimization or not skipped_backward: + # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin. + deepspeed_engine = model.trainer.model + deepspeed_engine.step() return False def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None: diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index dc2a2a5cbab33..ed15a117ff497 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -95,14 +95,13 @@ def pre_optimizer_step( f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})." " To request, please file a Github issue in PyTorch and tag @mcarilli" ) - result = True - # FIXME: is this correct for manual? - if model.automatic_optimization: - result = lambda_closure() + result = lambda_closure() self.scaler.unscale_(optimizer) super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs) - # lambda_closure returning None indicates that backward has been skipped - if result is not None: + skipped_backward = result is None + # in manual optimization, the closure does not return a value + if not model.automatic_optimization or not skipped_backward: + # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found self.scaler.step(optimizer) self.scaler.update() return False From d8a57e76bd69340e7bdef28c37833ad5ea02504b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Sep 2021 14:51:39 +0200 Subject: [PATCH 04/13] Fix for deepspeed --- pytorch_lightning/core/lightning.py | 2 +- .../plugins/precision/deepspeed_precision.py | 13 ++++++++----- pytorch_lightning/plugins/precision/native_amp.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 1004db58bd609..299818c4e38d7 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -633,7 +633,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: - ``None`` - Training will skip to the next batch Note: - Returning ``None`` is currently not supported for multi-GPU or TPU. + Returning ``None`` is currently not supported for multi-GPU or TPU, or using `DeepSpeed`. In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific. diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index b931dbafb7058..52e788956aaef 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -20,6 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.warnings import WarningCache @@ -44,12 +45,14 @@ def pre_optimizer_step( """Hook to do something before each optimizer step.""" result = lambda_closure() # DeepSpeed does not support closures super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs) - skipped_backward = result is None # in manual optimization, the closure does not return a value - if not model.automatic_optimization or not skipped_backward: - # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin. - deepspeed_engine = model.trainer.model - deepspeed_engine.step() + if model.automatic_optimization and result is None: + raise MisconfigurationException( + "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`" + ) + # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin. + deepspeed_engine = model.trainer.model + deepspeed_engine.step() return False def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None: diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index ed15a117ff497..1d04832b1b60f 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -95,7 +95,7 @@ def pre_optimizer_step( f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})." " To request, please file a Github issue in PyTorch and tag @mcarilli" ) - result = lambda_closure() + result = lambda_closure() # native amp does not support closures self.scaler.unscale_(optimizer) super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs) skipped_backward = result is None From b945c1df5ee93b20aab04c25a6b1eab0d08cd61b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Sep 2021 15:03:17 +0200 Subject: [PATCH 05/13] Typo --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 299818c4e38d7..558649ca34d4d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -633,7 +633,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: - ``None`` - Training will skip to the next batch Note: - Returning ``None`` is currently not supported for multi-GPU or TPU, or using `DeepSpeed`. + Returning ``None`` is currently not supported for multi-GPU or TPU, or using ``DeepSpeed``. In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific. From 5696fb15d91faebec42a8833f3161e195b4094cd Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Sep 2021 15:29:11 +0200 Subject: [PATCH 06/13] Hook test for manual closure --- tests/models/test_hooks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 338fdb4003a6f..77d9fce06d6ca 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -321,6 +321,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre @staticmethod def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **kwargs): using_deepspeed = kwargs.get("plugins") == "deepspeed" + using_plugin = kwargs.get("amp_backend") or kwargs.get("plugins") out = [] for i in range(batches): out.extend( @@ -341,8 +342,11 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k dict(name="on_after_backward"), # `manual_backward` calls the previous 3 dict(name="manual_backward", args=(ANY,)), + *([dict(name="closure")] if using_plugin else []), dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)), dict(name="on_before_optimizer_step", args=(ANY, 0)), + # without a precision plugin, we execute the closure inside the `optimizer.step` + *([] if using_plugin else [dict(name="closure")]), dict(name="training_step", args=(ANY, i)), dict(name="training_step_end", args=(dict(loss=ANY),)), dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i, 0)), @@ -438,7 +442,7 @@ def training_step(self, batch, batch_idx): opt = self.optimizers() opt.zero_grad() self.manual_backward(loss) - opt.step() + opt.step(lambda: called.append({"name": "closure"})) return {"loss": loss} model = TestModel(called) From 35a7bbcd080962fd53dc52863cb824e99dd222f5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Sep 2021 16:20:20 +0200 Subject: [PATCH 07/13] Add skipping test with AMP --- pytorch_lightning/core/lightning.py | 6 ++---- .../optimization/test_manual_optimization.py | 20 ++++++++++++++++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 558649ca34d4d..8ca87f7b2bb00 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -630,10 +630,8 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: - :class:`~torch.Tensor` - The loss tensor - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'`` - - ``None`` - Training will skip to the next batch - - Note: - Returning ``None`` is currently not supported for multi-GPU or TPU, or using ``DeepSpeed``. + - ``None`` - Training will skip to the next batch. This is only for automatic optimization. + This is not supported for multi-GPU or TPU, or using ``DeepSpeed``. In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific. diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 72ebd62ae499e..690d27138e8ab 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -64,7 +64,14 @@ def configure_optimizers(self): return optimizer, optimizer_2 -def test_multiple_optimizers_manual_no_return(tmpdir): +@pytest.mark.parametrize( + "kwargs", + [ + {}, + pytest.param({"gpus": 1, "precision": 16, "amp_backend": "native"}, marks=RunIf(amp_native=True, min_gpus=1)), + ], +) +def test_multiple_optimizers_manual_no_return(tmpdir, kwargs): class TestModel(ManualOptModel): def training_step(self, batch, batch_idx): # avoid returning a value @@ -86,12 +93,23 @@ def training_epoch_end(self, outputs) -> None: max_epochs=1, log_every_n_steps=1, weights_summary=None, + **kwargs, ) + if kwargs: + step_mock_patch = mock.patch.object( + trainer.precision_plugin.scaler, "step", wraps=trainer.precision_plugin.scaler.step + ) + step_mock = step_mock_patch.start() + with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 + if kwargs: + step_mock_patch.stop() + assert step_mock.call_count == len(model.optimizers()) * limit_train_batches + def test_multiple_optimizers_manual_return(tmpdir): class TestModel(ManualOptModel): From 9b7df18daccc4e84dd056272f14b57ce06742093 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Sep 2021 17:17:06 +0200 Subject: [PATCH 08/13] You are hideous, apex --- .../optimization/test_manual_optimization.py | 77 +++++++++---------- 1 file changed, 36 insertions(+), 41 deletions(-) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 690d27138e8ab..6ba4aa3489c0a 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -69,19 +69,45 @@ def configure_optimizers(self): [ {}, pytest.param({"gpus": 1, "precision": 16, "amp_backend": "native"}, marks=RunIf(amp_native=True, min_gpus=1)), + pytest.param( + {"gpus": 1, "precision": 16, "amp_backend": "apex", "amp_level": "O2"}, + marks=RunIf(amp_apex=True, min_gpus=1), + ), ], ) def test_multiple_optimizers_manual_no_return(tmpdir, kwargs): + apex_optimizer_patches = [] + apex_optimizer_steps = [] + class TestModel(ManualOptModel): def training_step(self, batch, batch_idx): # avoid returning a value super().training_step(batch, batch_idx) - def training_epoch_end(self, outputs) -> None: + def training_epoch_end(self, outputs): # outputs is empty as training_step does not return # and it is not automatic optimization assert not outputs + def on_train_start(self): + if kwargs.get("amp_backend") != "apex": + return + # extremely ugly. APEX patches all the native torch optimizers on `_initialize` which we call on + # `ApexMixedPrecisionPlugin.dispatch`. Additionally, their replacement `new_step` functions are locally + # defined so can't even patch those, thus we need to create the mock after APEX has been initialized + nonlocal apex_optimizer_patches, apex_optimizer_steps + for opt in self.trainer.optimizers: + # `amp.scale_loss` will also patch the step to avoid it when gradient overflow happens. avoid it + opt._amp_stash.already_patched = True + patch = mock.patch.object(opt, "step") + apex_optimizer_patches.append(patch) + apex_optimizer_steps.append(patch.start()) + + def on_train_end(self): + if kwargs.get("amp_backend") == "apex": + for p in apex_optimizer_patches: + p.stop() + model = TestModel() model.val_dataloader = None @@ -96,19 +122,22 @@ def training_epoch_end(self, outputs) -> None: **kwargs, ) - if kwargs: - step_mock_patch = mock.patch.object( + if kwargs.get("amp_backend") == "native": + # mock the scaler instead of the optimizer step because it can be skipped with NaNs + scaler_step_patch = mock.patch.object( trainer.precision_plugin.scaler, "step", wraps=trainer.precision_plugin.scaler.step ) - step_mock = step_mock_patch.start() + scaler_step = scaler_step_patch.start() with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 - if kwargs: - step_mock_patch.stop() - assert step_mock.call_count == len(model.optimizers()) * limit_train_batches + if kwargs.get("amp_backend") == "native": + scaler_step_patch.stop() + assert scaler_step.call_count == len(model.optimizers()) * limit_train_batches + if kwargs.get("amp_backend") == "apex": + assert [s.call_count for s in apex_optimizer_steps] == [len(model.optimizers())] * limit_train_batches def test_multiple_optimizers_manual_return(tmpdir): @@ -189,40 +218,6 @@ def test_multiple_optimizers_manual_native_amp(tmpdir): assert bwd_mock.call_count == limit_train_batches * 3 -@RunIf(min_gpus=1, amp_apex=True) -def test_multiple_optimizers_manual_apex_no_return(tmpdir): - class TestModel(ManualOptModel): - def training_step(self, batch, batch_idx): - # avoid returning a value - super().training_step(batch, batch_idx) - - def training_epoch_end(self, outputs) -> None: - # outputs is empty as training_step does not return - # and it is not automatic optimization - assert len(outputs) == 0 - - model = TestModel() - model.val_dataloader = None - - limit_train_batches = 2 - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=limit_train_batches, - limit_val_batches=2, - max_epochs=1, - log_every_n_steps=1, - weights_summary=None, - precision=16, - amp_level="O2", - amp_backend="apex", - gpus=1, - ) - - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: - trainer.fit(model) - assert bwd_mock.call_count == limit_train_batches * 3 - - class ManualOptimizationExtendedModel(BoringModel): count = 0 From 1ba6432d9874435ce3d4ccab59979989f93a5b85 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Sep 2021 17:29:14 +0200 Subject: [PATCH 09/13] Add deepspeed test --- tests/plugins/test_deepspeed_plugin.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 9b4a1f8a4ba99..750c6926b8eb2 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -791,3 +791,14 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir): trainer.fit(model) _assert_save_model_is_equal(model, tmpdir, trainer) + + +def test_deepspeed_skip_backward_raises(tmpdir): + class TestModel(BoringModel): + def training_step(self, batch, batch_idx): + return None + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, plugins=[DeepSpeedPlugin()], gpus=1, fast_dev_run=True, precision=16) + with pytest.raises(MisconfigurationException, match="returning `None` .* is not supported"): + trainer.fit(model) From 33ecf1332b80cfb913ef819280b90f301adc52d9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Sep 2021 17:31:18 +0200 Subject: [PATCH 10/13] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d858282d3533a..aa3669c40ccf3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the Apex and DeepSpeed plugin closure running after the `on_before_optimizer_step` hook ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288)) +- Fixed the Native AMP plugin closure not running with manual optimization ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288)) + + - Fixed bug where data-loading functions where not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858)) From d09e753ef257b1ff08dc0bf8f66f21dd78c9889b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Sep 2021 17:33:44 +0200 Subject: [PATCH 11/13] Fix for broken master --- tests/trainer/loops/test_training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index c37681e4831ca..d21e8efc7a5cb 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -191,7 +191,7 @@ def training_epoch_end(self, outputs) -> None: def test_batch_loop_releases_loss(tmpdir): - """Test that loss/graph is released so that it can be garbage collected before the next training step""" + """Test that loss/graph is released so that it can be garbage collected before the next training step.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): From df99e8d6df84f48812a8ddd61001c40a3633bfe7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Sep 2021 17:51:03 +0200 Subject: [PATCH 12/13] Add RunIf --- tests/plugins/test_deepspeed_plugin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 750c6926b8eb2..71cab78d90b49 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -793,6 +793,7 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir): _assert_save_model_is_equal(model, tmpdir, trainer) +@RunIf(min_gpus=2, deepspeed=True) def test_deepspeed_skip_backward_raises(tmpdir): class TestModel(BoringModel): def training_step(self, batch, batch_idx): From 2db195d2017f9c8c1d619efe19b58e91530d17e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 7 Sep 2021 12:13:40 +0200 Subject: [PATCH 13/13] Update tests/plugins/test_deepspeed_plugin.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- tests/plugins/test_deepspeed_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 71cab78d90b49..212289767f2c5 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -793,7 +793,7 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir): _assert_save_model_is_equal(model, tmpdir, trainer) -@RunIf(min_gpus=2, deepspeed=True) +@RunIf(min_gpus=1, deepspeed=True) def test_deepspeed_skip_backward_raises(tmpdir): class TestModel(BoringModel): def training_step(self, batch, batch_idx):