diff --git a/CHANGELOG.md b/CHANGELOG.md index d8745778478d4..171aa8e3944a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -192,6 +192,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068)) + +- Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148)) + ### Removed - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507)) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 62ec92dad51b2..a1002bfd55621 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -13,27 +13,376 @@ # limitations under the License. from abc import ABC from copy import deepcopy -from typing import Any, Dict, List, Type, Union +from typing import Any, Dict, List, Optional, Type, Union +import torch from packaging.version import Version import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import STEP_OUTPUT class TrainerCallbackHookMixin(ABC): + r""" + .. deprecated:: v1.6 + The `TrainerCallbackHookMixin` class was deprecated in v1.6 and will be removed in v1.8. + """ # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class callbacks: List[Callback] = [] lightning_module: "pl.LightningModule" + def on_before_accelerator_backend_setup(self) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_before_accelerator_backend_setup` was deprecated in v1.6 + and will be removed in v1.8. + + Called at the beginning of fit (train + validate), validate, test, or predict, or tune. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_before_accelerator_backend_setup` was deprecated in v1.6 " + "and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_before_accelerator_backend_setup(self, self.lightning_module) + + def on_configure_sharded_model(self) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_configure_sharded_model` was deprecated in v1.6 and will be removed in v1.8. + + Called at the beginning of fit (train + validate), validate, test, or predict, or tune. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_configure_sharded_model` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_configure_sharded_model(self, self.lightning_module) + + def setup(self, stage: Optional[str]) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.setup` was deprecated in v1.6 and will be removed in v1.8. + + Called at the beginning of fit (train + validate), validate, test, or predict, or tune. + """ + rank_zero_deprecation("`TrainerCallbackHookMixin.setup` was deprecated in v1.6 and will be removed in v1.8.") + for callback in self.callbacks: + callback.setup(self, self.lightning_module, stage=stage) + + def teardown(self, stage: Optional[str] = None) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.teardown` was deprecated in v1.6 and will be removed in v1.8. + + Called at the end of fit (train + validate), validate, test, or predict, or tune. + """ + rank_zero_deprecation("`TrainerCallbackHookMixin.teardown` was deprecated in v1.6 and will be removed in v1.8.") + for callback in self.callbacks: + callback.teardown(self, self.lightning_module, stage=stage) + + def on_init_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_init_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the trainer initialization begins, model has not yet been set. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_init_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_init_start(self) + + def on_init_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_init_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the trainer initialization ends, model has not yet been set. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_init_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_init_end(self) + + def on_fit_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_fit_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the trainer initialization begins, model has not yet been set. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_fit_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_fit_start(self, self.lightning_module) + + def on_fit_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_fit_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the trainer initialization begins, model has not yet been set. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_fit_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_fit_end(self, self.lightning_module) + + def on_sanity_check_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_sanity_check_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the validation sanity check starts. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_sanity_check_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_sanity_check_start(self, self.lightning_module) + + def on_sanity_check_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_sanity_check_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the validation sanity check ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_sanity_check_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_sanity_check_end(self, self.lightning_module) + + def on_train_epoch_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_epoch_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the epoch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_epoch_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_train_epoch_start(self, self.lightning_module) + + def on_train_epoch_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_epoch_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the epoch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_epoch_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_train_epoch_end(self, self.lightning_module) + + def on_validation_epoch_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_epoch_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the epoch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_epoch_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_validation_epoch_start(self, self.lightning_module) + + def on_validation_epoch_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_epoch_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the validation epoch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_epoch_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_validation_epoch_end(self, self.lightning_module) + + def on_test_epoch_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_epoch_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the epoch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_epoch_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_test_epoch_start(self, self.lightning_module) + + def on_test_epoch_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_epoch_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the test epoch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_epoch_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_test_epoch_end(self, self.lightning_module) + + def on_predict_epoch_start(self) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_epoch_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the epoch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_epoch_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_predict_epoch_start(self, self.lightning_module) + + def on_predict_epoch_end(self, outputs: List[Any]) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_epoch_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the epoch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_epoch_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_predict_epoch_end(self, self.lightning_module, outputs) + + def on_epoch_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_epoch_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when either of train/val/test epoch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_epoch_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_epoch_start(self, self.lightning_module) + + def on_epoch_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_epoch_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when either of train/val/test epoch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_epoch_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_epoch_end(self, self.lightning_module) + + def on_train_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the train begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_train_start(self, self.lightning_module) + + def on_train_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the train ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_train_end(self, self.lightning_module) + + def on_pretrain_routine_start(self) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_pretrain_routine_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the pre-train routine begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_pretrain_routine_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_pretrain_routine_start(self, self.lightning_module) + + def on_pretrain_routine_end(self) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_pretrain_routine_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the pre-train routine ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_pretrain_routine_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_pretrain_routine_end(self, self.lightning_module) + + def on_batch_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_batch_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the training batch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_batch_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_batch_start(self, self.lightning_module) + + def on_batch_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_batch_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the training batch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_batch_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_batch_end(self, self.lightning_module) + # TODO: Update this in v1.7 (deprecation: #9816) def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): - """Called when the training batch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the training batch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8." + ) for callback in self.callbacks: if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True): callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, 0) @@ -42,15 +391,210 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): # TODO: Update this in v1.7 (deprecation: #9816) def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0): - """Called when the training batch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the training batch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.8." + ) for callback in self.callbacks: if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True): callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0) else: callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx) + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_batch_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the validation batch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_batch_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) + + def on_validation_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_batch_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the validation batch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_batch_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) + + def on_test_batch_start(self, batch, batch_idx, dataloader_idx): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the test batch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) + + def on_test_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_batch_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the test batch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_batch_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_test_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) + + def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_batch_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the predict batch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_batch_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_predict_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) + + def on_predict_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the predict batch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_predict_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) + + def on_validation_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the validation loop begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_validation_start(self, self.lightning_module) + + def on_validation_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the validation loop ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_validation_end(self, self.lightning_module) + + def on_test_start(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when the test begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_test_start(self, self.lightning_module) + + def on_test_end(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when the test ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_test_end(self, self.lightning_module) + + def on_predict_start(self) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_start` was deprecated in v1.6 and will be removed in v1.8. + + Called when predict begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_start` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_predict_start(self, self.lightning_module) + + def on_predict_end(self) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_end` was deprecated in v1.6 and will be removed in v1.8. + + Called when predict ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_end` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_predict_end(self, self.lightning_module) + + def on_keyboard_interrupt(self): + r""" + .. deprecated:: v1.5 + This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7. + + Called when any trainer execution is interrupted by KeyboardInterrupt. + """ + for callback in self.callbacks: + callback.on_keyboard_interrupt(self, self.lightning_module) + + def on_exception(self, exception: BaseException) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_exception` was deprecated in v1.6 and will be removed in v1.8. + + Called when any trainer execution is interrupted by an exception. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_exception` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_exception(self, self.lightning_module, exception) + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: - """Called when saving a model checkpoint.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.8. + + Called when saving a model checkpoint. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.8." + ) callback_states = {} for callback in self.callbacks: state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) @@ -59,10 +603,18 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: return callback_states def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """Called when loading a model checkpoint.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8. + + Called when loading a model checkpoint. + """ # Todo: the `callback_states` are dropped with TPUSpawn as they # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8." + ) callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") if callback_states is None: @@ -83,3 +635,55 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: if state: state = deepcopy(state) callback.on_load_checkpoint(self, self.lightning_module, state) + + def on_before_backward(self, loss: torch.Tensor) -> None: + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_before_backward` was deprecated in v1.6 and will be removed in v1.8. + + Called before ``loss.backward()``. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_before_backward` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_before_backward(self, self.lightning_module, loss) + + def on_after_backward(self): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.8. + + Called after loss.backward() and before optimizers do anything. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_after_backward(self, self.lightning_module) + + def on_before_optimizer_step(self, optimizer, optimizer_idx): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.8. + + Called after on_after_backward() once the gradient is accumulated and before optimizer.step(). + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_before_optimizer_step(self, self.lightning_module, optimizer, optimizer_idx) + + def on_before_zero_grad(self, optimizer): + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_before_zero_grad` was deprecated in v1.6 and will be removed in v1.8. + + Called after optimizer.step() and before optimizer.zero_grad(). + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_before_zero_grad` was deprecated in v1.6 and will be removed in v1.8." + ) + for callback in self.callbacks: + callback.on_before_zero_grad(self, self.lightning_module, optimizer) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 36f2c68594316..6b6b3abed91c2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -101,12 +101,14 @@ from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import ( _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT, EVAL_DATALOADERS, LRSchedulerTypeUnion, + STEP_OUTPUT, TRAIN_DATALOADERS, ) from pytorch_lightning.utilities.warnings import PossibleUserWarning @@ -1552,11 +1554,12 @@ def _call_callback_hooks( pl_module._current_fx_name = hook_name # TODO: remove if block in v1.7 - if hook_name in ("on_train_batch_start", "on_train_batch_end"): - fn = getattr(self, hook_name) - if callable(fn): - with self.profiler.profile(hook_name): - fn(*args, **kwargs) + if hook_name == "on_train_batch_start": + with self.profiler.profile(hook_name): + self._on_train_batch_start(*args, **kwargs) + elif hook_name == "on_train_batch_end": + with self.profiler.profile(hook_name): + self._on_train_batch_end(*args, **kwargs) else: for callback in self.callbacks: fn = getattr(callback, hook_name) @@ -1568,6 +1571,28 @@ def _call_callback_hooks( # restore current_fx when nested context pl_module._current_fx_name = prev_fx_name + # TODO: Delete this in v1.7 (deprecations: #9816 and #11148) + def _on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): + r"""Called when the training batch begins. This function is needed because of two different deprecations affecting + the original function in TrainerCallbackHookMixin: #9816 and #11148. + """ + for callback in self.callbacks: + if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True): + callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, 0) + else: + callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx) + + # TODO: Delete this in v1.7 (deprecations: #9816 and #11148) + def _on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0): + r"""Called when the training batch ends. This function is needed because of two different deprecations affecting + the original function in TrainerCallbackHookMixin: #9816 and #11148. + """ + for callback in self.callbacks: + if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True): + callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0) + else: + callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx) + def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: """Called when saving a model checkpoint. diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 67ecb736ed6c4..b87a43b8f4939 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -16,6 +16,7 @@ import pytest import torch +from torch import optim from pytorch_lightning import Callback, Trainer from pytorch_lightning.utilities import rank_zero_warn @@ -137,3 +138,97 @@ def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir): match=r"`Trainer.should_rank_save_checkpoint` is deprecated in v1.6 and will be removed in 1.8." ): _ = trainer.should_rank_save_checkpoint + + +def test_v1_8_0_deprecate_trainer_callback_hook_mixin(): + methods_with_self = [ + "on_before_accelerator_backend_setup", + "on_configure_sharded_model", + "on_init_start", + "on_init_end", + "on_fit_start", + "on_fit_end", + "on_sanity_check_start", + "on_sanity_check_end", + "on_train_epoch_start", + "on_train_epoch_end", + "on_validation_epoch_start", + "on_validation_epoch_end", + "on_test_epoch_start", + "on_test_epoch_end", + "on_predict_epoch_start", + "on_epoch_start", + "on_epoch_end", + "on_train_start", + "on_train_end", + "on_pretrain_routine_start", + "on_pretrain_routine_end", + "on_batch_start", + "on_batch_end", + "on_validation_start", + "on_validation_end", + "on_test_start", + "on_test_end", + "on_predict_start", + "on_predict_end", + "on_after_backward", + ] + methods_with_stage = [ + "setup", + "teardown", + ] + methods_with_batch_batch_idx_dataloader_idx = [ + "on_train_batch_start", + "on_validation_batch_start", + "on_test_batch_start", + "on_predict_batch_start", + ] + methods_with_outputs_batch_batch_idx_dataloader_idx = [ + "on_train_batch_end", + "on_validation_batch_end", + "on_test_batch_end", + "on_predict_batch_end", + ] + methods_with_checkpoint = ["on_save_checkpoint", "on_load_checkpoint"] + trainer = Trainer( + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + enable_progress_bar=False, + logger=False, + ) + model = BoringModel() + # need to attach model to trainer for testing of `on_pretrain_routine_start` + trainer.fit(model) + for method_name in methods_with_self: + fn = getattr(trainer, method_name, None) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + fn() + for method_name in methods_with_stage: + fn = getattr(trainer, method_name) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + fn(stage="test") + for method_name in methods_with_batch_batch_idx_dataloader_idx: + fn = getattr(trainer, method_name) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + fn(batch={}, batch_idx=0, dataloader_idx=0) + for method_name in methods_with_outputs_batch_batch_idx_dataloader_idx: + fn = getattr(trainer, method_name) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + fn(outputs=torch.tensor([[1.0, -1.0], [1.0, -1.0]]), batch={}, batch_idx=0, dataloader_idx=0) + for method_name in methods_with_checkpoint: + fn = getattr(trainer, method_name) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + fn(checkpoint={}) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + trainer.on_predict_epoch_end(outputs=torch.tensor([[1.0, -1.0], [1.0, -1.0]])) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + trainer.on_exception(exception=Exception) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + trainer.on_before_backward(loss=torch.tensor([[1.0, -1.0], [1.0, -1.0]])) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + trainer.on_before_optimizer_step( + optimizer=optim.SGD(model.parameters(), lr=0.01, momentum=0.9), optimizer_idx=0 + ) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + trainer.on_before_zero_grad(optimizer=optim.SGD(model.parameters(), lr=0.01, momentum=0.9))