diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0e8e1749f4350..1d8dc8647bc45 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -259,6 +259,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `PrecisionPlugin.backward` hooks no longer takes a `should_accumulate` argument ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
 
+- Added the `on_before_backward` hook ([#7865](https://github.com/PyTorchLightning/pytorch-lightning/pull/7865))
+
+
 - `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run`([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))
 
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index 6043eab649ebf..40c0ef92a8e3d 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1191,6 +1191,7 @@ for more information.
         on_before_zero_grad()
         optimizer_zero_grad()
 
+        on_before_backward()
         backward()
         on_after_backward()
 
@@ -1246,6 +1247,12 @@ get_progress_bar_dict
 .. automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict
     :noindex:
 
+on_before_backward
+~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_backward
+    :noindex:
+
 on_after_backward
 ~~~~~~~~~~~~~~~~~
 
diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index 01bb02f5a74b9..a905958aac00b 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -351,6 +351,12 @@ on_load_checkpoint
 .. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
     :noindex:
 
+on_before_backward
+^^^^^^^^^^^^^^^^^^
+
+.. automethod:: pytorch_lightning.callbacks.Callback.on_before_backward
+    :noindex:
+
 on_after_backward
 ^^^^^^^^^^^^^^^^^
 
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index c35fc64e2e115..5f3fc5b8d0562 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -19,6 +19,7 @@
 import abc
 from typing import Any, Dict, List, Optional
 
+import torch
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
@@ -296,6 +297,10 @@ def on_load_checkpoint(
         """
         pass
 
+    def on_before_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', loss: torch.Tensor) -> None:
+        """Called before ``loss.backward()``."""
+        pass
+
     def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
         """Called after ``loss.backward()`` and before optimizers do anything."""
         pass
diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py
index 33250fdd044f8..e10274bd59b06 100644
--- a/pytorch_lightning/callbacks/lambda_function.py
+++ b/pytorch_lightning/callbacks/lambda_function.py
@@ -77,6 +77,7 @@ def __init__(
         on_keyboard_interrupt: Optional[Callable] = None,
         on_save_checkpoint: Optional[Callable] = None,
         on_load_checkpoint: Optional[Callable] = None,
+        on_before_backward: Optional[Callable] = None,
         on_after_backward: Optional[Callable] = None,
         on_before_zero_grad: Optional[Callable] = None,
         on_predict_start: Optional[Callable] = None,
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index e0b4f7c74e477..c99350879aae7 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -295,6 +295,15 @@ def on_before_zero_grad(self, optimizer: Optimizer) -> None:
             optimizer: The optimizer for which grads should be zeroed.
         """
 
+    def on_before_backward(self, loss: torch.Tensor) -> None:
+        """
+        Called before ``loss.backward()``.
+
+        Args:
+            loss: Loss divided by number of batches for gradient accumulation and scaled if using native AMP.
+        """
+        pass
+
     def on_after_backward(self) -> None:
         """
         Called in the training loop after loss.backward() and before optimizers do anything.
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 3e37de6b5d07b..c40025dd1d935 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -42,7 +42,8 @@ def pre_backward(
         model: 'pl.LightningModule',
         closure_loss: torch.Tensor,
     ) -> torch.Tensor:
-        return self.scaler.scale(closure_loss)
+        closure_loss = self.scaler.scale(closure_loss)
+        return super().pre_backward(model, closure_loss)
 
     def pre_optimizer_step(
         self,
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index edba2e4b15240..fea02f87baa57 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -62,6 +62,7 @@ def pre_backward(
             model: the model to be optimized
             closure_loss: the loss value obtained from the closure
         """
+        model.trainer.call_hook("on_before_backward", closure_loss)
         return closure_loss
 
     def backward(
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 4f4e44e57d3a3..098507569499f 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -17,6 +17,8 @@
 from inspect import signature
 from typing import Any, Callable, Dict, List, Optional, Type
 
+import torch
+
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
@@ -313,6 +315,11 @@ def on_load_checkpoint(self, checkpoint):
             else:
                 callback.on_load_checkpoint(self, self.lightning_module, state)
 
+    def on_before_backward(self, loss: torch.Tensor) -> None:
+        """Called before ``loss.backward()``."""
+        for callback in self.callbacks:
+            callback.on_before_backward(self, self.lightning_module, loss)
+
     def on_after_backward(self):
         """
         Called after loss.backward() and before optimizers do anything.
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
index 8d079f8b4a637..d66b069817918 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
@@ -21,6 +21,7 @@ class FxValidator:
     functions: Dict[str, Optional[Dict[str, Tuple[bool]]]] = dict(
         on_before_accelerator_backend_setup=None,
         on_configure_sharded_model=None,
+        on_before_backward=dict(on_step=(False, True), on_epoch=(False, True)),
         on_after_backward=dict(on_step=(False, True), on_epoch=(False, True)),
         on_before_zero_grad=dict(on_step=(False, True), on_epoch=(False, True)),
         on_init_start=None,
diff --git a/tests/helpers/datamodules.py b/tests/helpers/datamodules.py
index 12ec16261159d..2fc9f8a901f22 100644
--- a/tests/helpers/datamodules.py
+++ b/tests/helpers/datamodules.py
@@ -24,6 +24,10 @@
 if _SKLEARN_AVAILABLE:
     from sklearn.datasets import make_classification, make_regression
     from sklearn.model_selection import train_test_split
+else:
+    make_classification = None
+    make_regression = None
+    train_test_split = None
 
 
 class MNISTDataModule(LightningDataModule):
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 789959e38908a..6987977e46e0a 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -256,9 +256,12 @@ class HookedModel(BoringModel):
     def __init__(self, called):
         super().__init__()
         pl_module_hooks = get_members(LightningModule)
+        # remove non-hooks
+        pl_module_hooks.difference_update({'optimizers'})
         # remove most `nn.Module` hooks
         module_hooks = get_members(torch.nn.Module)
-        pl_module_hooks.difference_update(module_hooks - {'forward', 'zero_grad', 'train'})
+        module_hooks.difference_update({'forward', 'zero_grad', 'train'})
+        pl_module_hooks.difference_update(module_hooks)
 
         def call(hook, fn, *args, **kwargs):
             out = fn(*args, **kwargs)
@@ -286,9 +289,15 @@ def test_epoch_end(self, *args, **kwargs):
         # `BoringModel` does not have a return for `test_step_end` so this would fail
         pass
 
+    def _train_batch(self, *args, **kwargs):
+        if self.automatic_optimization:
+            return self._auto_train_batch(*args, **kwargs)
+        return self._manual_train_batch(*args, **kwargs)
+
     @staticmethod
-    def _train_batch(trainer, model, batches, device=torch.device('cpu'), current_epoch=0, **kwargs):
+    def _auto_train_batch(trainer, model, batches, device=torch.device('cpu'), current_epoch=0, **kwargs):
         using_native_amp = kwargs.get('amp_backend') == 'native'
+        using_deepspeed = kwargs.get('plugins') == 'deepspeed'
         out = []
         for i in range(batches):
             out.extend([
@@ -299,18 +308,19 @@ def _train_batch(trainer, model, batches, device=torch.device('cpu'), current_ep
                 dict(name='Callback.on_batch_start', args=(trainer, model)),
                 dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
                 dict(name='on_train_batch_start', args=(ANY, i, 0)),
+                # TODO: `on_before_optimizer_step`
                 dict(name='forward', args=(ANY, )),
                 dict(name='training_step', args=(ANY, i)),
                 dict(name='training_step_end', args=(dict(loss=ANY), )),
                 dict(name='Callback.on_before_zero_grad', args=(trainer, model, ANY)),
                 dict(name='on_before_zero_grad', args=(ANY, )),
                 dict(name='optimizer_zero_grad', args=(current_epoch, i, ANY, 0)),
-                # TODO: `on_before_backward`
+                dict(name='Callback.on_before_backward', args=(trainer, model, ANY)),
+                dict(name='on_before_backward', args=(ANY, )),
                 # DeepSpeed handles backward internally
-                *([dict(name='backward', args=(ANY, ANY, 0))] if kwargs.get('plugins') != 'deepspeed' else []),
+                *([dict(name='backward', args=(ANY, ANY, 0))] if not using_deepspeed else []),
                 dict(name='Callback.on_after_backward', args=(trainer, model)),
                 dict(name='on_after_backward'),
-                # TODO: `on_before_optimizer_step`
                 dict(
                     name='optimizer_step',
                     args=(current_epoch, i, ANY, 0, ANY),
@@ -322,6 +332,37 @@ def _train_batch(trainer, model, batches, device=torch.device('cpu'), current_ep
             ])
         return out
 
+    @staticmethod
+    def _manual_train_batch(trainer, model, batches, device=torch.device('cpu'), **kwargs):
+        using_deepspeed = kwargs.get('plugins') == 'deepspeed'
+        out = []
+        for i in range(batches):
+            out.extend([
+                dict(name='on_before_batch_transfer', args=(ANY, 0)),
+                dict(name='transfer_batch_to_device', args=(ANY, device, 0)),
+                dict(name='on_after_batch_transfer', args=(ANY, 0)),
+                # TODO: `on_batch_{start,end}`
+                dict(name='Callback.on_batch_start', args=(trainer, model)),
+                dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
+                dict(name='on_train_batch_start', args=(ANY, i, 0)),
+                dict(name='forward', args=(ANY, )),
+                dict(name='Callback.on_before_backward', args=(trainer, model, ANY)),
+                dict(name='on_before_backward', args=(ANY, )),
+                # DeepSpeed handles backward internally
+                *([dict(name='backward', args=(ANY, None, None))] if not using_deepspeed else []),
+                dict(name='Callback.on_after_backward', args=(trainer, model)),
+                dict(name='on_after_backward'),
+                # `manual_backward` calls the previous 3
+                dict(name='manual_backward', args=(ANY, )),
+                # TODO: `on_before_optimizer_step`
+                dict(name='training_step', args=(ANY, i)),
+                dict(name='training_step_end', args=(dict(loss=ANY), )),
+                dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
+                dict(name='on_train_batch_end', args=(dict(loss=ANY), ANY, i, 0)),
+                dict(name='Callback.on_batch_end', args=(trainer, model)),
+            ])
+        return out
+
     @staticmethod
     def _eval_epoch(fn, trainer, model, batches, key, device=torch.device('cpu')):
         outputs = {key: ANY}
@@ -388,9 +429,27 @@ def _predict_batch(trainer, model, batches):
         pytest.param(dict(gpus=1, precision=16, amp_backend='apex'), marks=RunIf(amp_apex=True, min_gpus=1)),
     ]
 )
-def test_trainer_model_hook_system_fit(tmpdir, kwargs):
+@pytest.mark.parametrize('automatic_optimization', (True, False))
+def test_trainer_model_hook_system_fit(tmpdir, kwargs, automatic_optimization):
     called = []
-    model = HookedModel(called)
+
+    class TestModel(HookedModel):
+
+        def __init__(self, *args):
+            super().__init__(*args)
+            self.automatic_optimization = automatic_optimization
+
+        def training_step(self, batch, batch_idx):
+            if self.automatic_optimization:
+                return super().training_step(batch, batch_idx)
+            loss = self.step(batch[0])
+            opt = self.optimizers()
+            opt.zero_grad()
+            self.manual_backward(loss)
+            opt.step()
+            return {'loss': loss}
+
+    model = TestModel(called)
     callback = HookedCallback(called)
     train_batches = 2
     val_batches = 2
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 68b0f2d9178a9..f3d89b54ae236 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -657,14 +657,19 @@ def _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimiz
 
     class VerificationCallback(Callback):
 
+        def __init__(self):
+            self.on_train_batch_start_called = False
+
         def on_train_batch_start(
             self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
         ) -> None:
             deepspeed_engine = trainer.training_type_plugin.model
             assert trainer.global_step == deepspeed_engine.global_steps
+            self.on_train_batch_start_called = True
 
     model = ModelParallelClassificationModel()
     dm = ClassifDataModule()
+    verification_callback = VerificationCallback()
     trainer = Trainer(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
@@ -674,9 +679,10 @@ def on_train_batch_start(
         limit_val_batches=2,
         precision=16,
         accumulate_grad_batches=2,
-        callbacks=[VerificationCallback()]
+        callbacks=[verification_callback]
     )
     trainer.fit(model, datamodule=dm)
+    assert verification_callback.on_train_batch_start_called
 
 
 @RunIf(min_gpus=2, deepspeed=True, special=True)
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 592fde1569344..431ca12ecab4b 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -32,6 +32,7 @@ def test_fx_validator(tmpdir):
     funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])
 
     callbacks_func = [
+        'on_before_backward',
         'on_after_backward',
         'on_batch_end',
         'on_batch_start',
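
---

For context, below is a minimal usage sketch (not part of the patch) of how the new hook could be consumed once this lands. The `LossInspector` callback, its non-finite check, and the commented-out `Trainer` wiring are hypothetical illustrations; only the `on_before_backward(trainer, pl_module, loss)` signature and the note about the loss being accumulation-divided and AMP-scaled come from the changes above.

```python
# Hypothetical usage sketch -- not part of this diff. Assumes only the
# Callback.on_before_backward signature introduced by this PR.
import torch

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class LossInspector(Callback):
    """Hypothetical callback that sanity-checks the loss right before ``loss.backward()``.

    Per the new ``ModelHooks.on_before_backward`` docstring, ``loss`` arrives already
    divided for gradient accumulation and scaled when native AMP is enabled.
    """

    def on_before_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', loss: torch.Tensor) -> None:
        # Fail fast if the loss is NaN/inf before gradients are computed.
        if not torch.isfinite(loss):
            raise RuntimeError(f"Non-finite loss before backward at step {trainer.global_step}: {loss}")


# Example wiring (`model` is any LightningModule):
# trainer = Trainer(callbacks=[LossInspector()])
# trainer.fit(model)
```

The same check could instead live on the module side by overriding the new `LightningModule.on_before_backward(self, loss)` hook, which receives only the loss.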