diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5a10d80b91678..aa3669c40ccf3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -285,6 +285,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
 
 
+- Fixed the Apex and DeepSpeed plugin closure running after the `on_before_optimizer_step` hook ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288))
+
+
+- Fixed the Native AMP plugin closure not running with manual optimization ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288))
+
+
 - Fixed bug where data-loading functions where not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
 
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index e1b4d1f3f7e58..8ca87f7b2bb00 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -630,10 +630,8 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
 
             - :class:`~torch.Tensor` - The loss tensor
             - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``
-            - ``None`` - Training will skip to the next batch
-
-            Note:
-                Returning ``None`` is currently not supported for multi-GPU or TPU, or with 16-bit precision enabled.
+            - ``None`` - Training will skip to the next batch. This is only supported for automatic optimization.
+              This is not supported for multi-GPU or TPU, or when using ``DeepSpeed``.
 
         In this step you'd normally do the forward pass and calculate the loss for a batch.
         You can also do fancier things like multiple forward passes or something model specific.
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index 297571d1174c8..27fa856d0f4b9 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -97,10 +97,13 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
+        result = lambda_closure()  # APEX amp does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        lambda_closure()  # APEX amp does not support closures
-        optimizer.step(**kwargs)
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # the following should be in an `optimizer_step` hook but we don't have one in the precision plugin.
+            optimizer.step(**kwargs)
         return False
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index c127a2076455f..52e788956aaef 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -20,6 +20,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import GradClipAlgorithmType
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -42,9 +43,14 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
+        result = lambda_closure()  # DeepSpeed does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
+        # in manual optimization, the closure does not return a value
+        if model.automatic_optimization and result is None:
+            raise MisconfigurationException(
+                "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
+            )
         # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        lambda_closure()  # DeepSpeed does not support closures
         deepspeed_engine = model.trainer.model
         deepspeed_engine.step()
         return False
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 9373625f66d02..1d04832b1b60f 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -95,13 +95,13 @@ def pre_optimizer_step(
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
                 " To request, please file a Github issue in PyTorch and tag @mcarilli"
             )
-        result = True
-        if model.automatic_optimization:
-            result = lambda_closure()
+        result = lambda_closure()  # native amp does not support closures
         self.scaler.unscale_(optimizer)
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # lambda_closure returning None indicates that backward has been skipped
-        if result is not None:
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
             self.scaler.step(optimizer)
             self.scaler.update()
         return False
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 4a787a833dd1c..77d9fce06d6ca 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -275,6 +275,7 @@ def _train_batch(self, *args, **kwargs):
     def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
         using_native_amp = kwargs.get("amp_backend") == "native"
         using_deepspeed = kwargs.get("plugins") == "deepspeed"
+        using_plugin = kwargs.get("amp_backend") or kwargs.get("plugins")
         out = []
         on_before_optimizer_step = [
             dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
@@ -290,10 +291,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     dict(name="Callback.on_batch_start", args=(trainer, model)),
                     dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i, 0)),
                     dict(name="on_train_batch_start", args=(ANY, i, 0)),
-                    # these are before the training step because
-                    # they are not part of the `training_step_and_backward` closure, however,
-                    # with native amp, the closure is run first and then the optimizer step.
-                    *(on_before_optimizer_step if not using_native_amp else []),
+                    # without a precision plugin, we execute the closure inside the `optimizer.step`
+                    *([] if using_plugin else on_before_optimizer_step),
                     dict(name="forward", args=(ANY,)),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
@@ -306,7 +305,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
                     dict(name="Callback.on_after_backward", args=(trainer, model)),
                     dict(name="on_after_backward"),
-                    *(on_before_optimizer_step if using_native_amp else []),
+                    *(on_before_optimizer_step if using_plugin else []),
                     dict(
                         name="optimizer_step",
                         args=(current_epoch, i, ANY, 0, ANY),
@@ -322,6 +321,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
     @staticmethod
     def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **kwargs):
         using_deepspeed = kwargs.get("plugins") == "deepspeed"
+        using_plugin = kwargs.get("amp_backend") or kwargs.get("plugins")
         out = []
         for i in range(batches):
             out.extend(
@@ -342,8 +342,11 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
                     dict(name="on_after_backward"),
                     # `manual_backward` calls the previous 3
                     dict(name="manual_backward", args=(ANY,)),
+                    *([dict(name="closure")] if using_plugin else []),
                     dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
                     dict(name="on_before_optimizer_step", args=(ANY, 0)),
+                    # without a precision plugin, we execute the closure inside the `optimizer.step`
+                    *([] if using_plugin else [dict(name="closure")]),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
                     dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
@@ -439,7 +442,7 @@ def training_step(self, batch, batch_idx):
             opt = self.optimizers()
             opt.zero_grad()
             self.manual_backward(loss)
-            opt.step()
+            opt.step(lambda: called.append({"name": "closure"}))
             return {"loss": loss}
 
     model = TestModel(called)
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 9b4a1f8a4ba99..212289767f2c5 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -791,3 +791,15 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir):
     trainer.fit(model)
 
     _assert_save_model_is_equal(model, tmpdir, trainer)
+
+
+@RunIf(min_gpus=1, deepspeed=True)
+def test_deepspeed_skip_backward_raises(tmpdir):
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            return None
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, plugins=[DeepSpeedPlugin()], gpus=1, fast_dev_run=True, precision=16)
+    with pytest.raises(MisconfigurationException, match="returning `None` .* is not supported"):
+        trainer.fit(model)
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index c37681e4831ca..d21e8efc7a5cb 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -191,7 +191,7 @@ def training_epoch_end(self, outputs) -> None:
 
 
 def test_batch_loop_releases_loss(tmpdir):
-    """Test that loss/graph is released so that it can be garbage collected before the next training step"""
+    """Test that loss/graph is released so that it can be garbage collected before the next training step."""
 
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 72ebd62ae499e..6ba4aa3489c0a 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -64,17 +64,50 @@ def configure_optimizers(self):
         return optimizer, optimizer_2
 
 
-def test_multiple_optimizers_manual_no_return(tmpdir):
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {},
+        pytest.param({"gpus": 1, "precision": 16, "amp_backend": "native"}, marks=RunIf(amp_native=True, min_gpus=1)),
+        pytest.param(
+            {"gpus": 1, "precision": 16, "amp_backend": "apex", "amp_level": "O2"},
+            marks=RunIf(amp_apex=True, min_gpus=1),
+        ),
+    ],
+)
+def test_multiple_optimizers_manual_no_return(tmpdir, kwargs):
+    apex_optimizer_patches = []
+    apex_optimizer_steps = []
+
     class TestModel(ManualOptModel):
         def training_step(self, batch, batch_idx):
             # avoid returning a value
             super().training_step(batch, batch_idx)
 
-        def training_epoch_end(self, outputs) -> None:
+        def training_epoch_end(self, outputs):
             # outputs is empty as training_step does not return
             # and it is not automatic optimization
             assert not outputs
 
+        def on_train_start(self):
+            if kwargs.get("amp_backend") != "apex":
+                return
+            # extremely ugly. APEX patches all the native torch optimizers on `_initialize` which we call on
+            # `ApexMixedPrecisionPlugin.dispatch`. Additionally, their replacement `new_step` functions are locally
+            # defined so we can't even patch those, thus we need to create the mock after APEX has been initialized
+            nonlocal apex_optimizer_patches, apex_optimizer_steps
+            for opt in self.trainer.optimizers:
+                # `amp.scale_loss` will also patch the step to skip it when gradient overflow happens; avoid that here
+                opt._amp_stash.already_patched = True
+                patch = mock.patch.object(opt, "step")
+                apex_optimizer_patches.append(patch)
+                apex_optimizer_steps.append(patch.start())
+
+        def on_train_end(self):
+            if kwargs.get("amp_backend") == "apex":
+                for p in apex_optimizer_patches:
+                    p.stop()
+
     model = TestModel()
     model.val_dataloader = None
 
@@ -86,12 +119,26 @@ def training_epoch_end(self, outputs) -> None:
         max_epochs=1,
         log_every_n_steps=1,
         weights_summary=None,
+        **kwargs,
     )
 
+    if kwargs.get("amp_backend") == "native":
+        # mock the scaler instead of the optimizer step because it can be skipped with NaNs
+        scaler_step_patch = mock.patch.object(
+            trainer.precision_plugin.scaler, "step", wraps=trainer.precision_plugin.scaler.step
+        )
+        scaler_step = scaler_step_patch.start()
+
     with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock:
         trainer.fit(model)
     assert bwd_mock.call_count == limit_train_batches * 3
 
+    if kwargs.get("amp_backend") == "native":
+        scaler_step_patch.stop()
+        assert scaler_step.call_count == len(model.optimizers()) * limit_train_batches
+    if kwargs.get("amp_backend") == "apex":
+        assert [s.call_count for s in apex_optimizer_steps] == [len(model.optimizers())] * limit_train_batches
+
 
 def test_multiple_optimizers_manual_return(tmpdir):
     class TestModel(ManualOptModel):
@@ -171,40 +218,6 @@ def test_multiple_optimizers_manual_native_amp(tmpdir):
     assert bwd_mock.call_count == limit_train_batches * 3
 
 
-@RunIf(min_gpus=1, amp_apex=True)
-def test_multiple_optimizers_manual_apex_no_return(tmpdir):
-    class TestModel(ManualOptModel):
-        def training_step(self, batch, batch_idx):
-            # avoid returning a value
-            super().training_step(batch, batch_idx)
-
-        def training_epoch_end(self, outputs) -> None:
-            # outputs is empty as training_step does not return
-            # and it is not automatic optimization
-            assert len(outputs) == 0
-
-    model = TestModel()
-    model.val_dataloader = None
-
-    limit_train_batches = 2
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
-        max_epochs=1,
-        log_every_n_steps=1,
-        weights_summary=None,
-        precision=16,
-        amp_level="O2",
-        amp_backend="apex",
-        gpus=1,
-    )
-
-    with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock:
-        trainer.fit(model)
-    assert bwd_mock.call_count == limit_train_batches * 3
-
-
 class ManualOptimizationExtendedModel(BoringModel):
 
     count = 0