From 6b01c49985ae79e4fc00c973c8f0fa904cca82eb Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Fri, 17 Dec 2021 16:55:10 -0800 Subject: [PATCH 1/7] put back functions in TrainerCallbackHookMixin --- pytorch_lightning/trainer/callback_hook.py | 273 ++++++++++++++++++++- 1 file changed, 272 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 62ec92dad51b2..1cca7b9f5fcd5 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -13,8 +13,9 @@ # limitations under the License. from abc import ABC from copy import deepcopy -from typing import Any, Dict, List, Type, Union +from typing import Any, Dict, List, Optional, Type, Union +import torch from packaging.version import Version import pytorch_lightning as pl @@ -31,7 +32,164 @@ class TrainerCallbackHookMixin(ABC): callbacks: List[Callback] = [] lightning_module: "pl.LightningModule" + # TODO: Delete this in v1.8 (deprecation #10979) + def on_before_accelerator_backend_setup(self) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.on_before_accelerator_backend_setup(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_configure_sharded_model(self) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.on_configure_sharded_model(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def setup(self, stage: Optional[str]) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.setup(self, self.lightning_module, stage=stage) + + # TODO: Delete this in v1.8 (deprecation #10979) + def teardown(self, stage: Optional[str] = None) -> None: + """Called at the end of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.teardown(self, self.lightning_module, stage=stage) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_init_start(self): + """Called when the trainer initialization begins, model has not yet been set.""" + for callback in self.callbacks: + callback.on_init_start(self) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_init_end(self): + """Called when the trainer initialization ends, model has not yet been set.""" + for callback in self.callbacks: + callback.on_init_end(self) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_fit_start(self): + """Called when the trainer initialization begins, model has not yet been set.""" + for callback in self.callbacks: + callback.on_fit_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_fit_end(self): + """Called when the trainer initialization begins, model has not yet been set.""" + for callback in self.callbacks: + callback.on_fit_end(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_sanity_check_start(self): + """Called when the validation sanity check starts.""" + for callback in self.callbacks: + callback.on_sanity_check_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_sanity_check_end(self): + """Called when the validation sanity check ends.""" + for callback in self.callbacks: + callback.on_sanity_check_end(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_train_epoch_start(self): + """Called when the epoch begins.""" + for callback in self.callbacks: + callback.on_train_epoch_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_train_epoch_end(self): + """Called when the epoch ends.""" + for callback in self.callbacks: + callback.on_train_epoch_end(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_validation_epoch_start(self): + """Called when the epoch begins.""" + for callback in self.callbacks: + callback.on_validation_epoch_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_validation_epoch_end(self): + """Called when the validation epoch ends.""" + for callback in self.callbacks: + callback.on_validation_epoch_end(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_test_epoch_start(self): + """Called when the epoch begins.""" + for callback in self.callbacks: + callback.on_test_epoch_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_test_epoch_end(self): + """Called when the test epoch ends.""" + for callback in self.callbacks: + callback.on_test_epoch_end(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_predict_epoch_start(self) -> None: + """Called when the epoch begins.""" + for callback in self.callbacks: + callback.on_predict_epoch_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_predict_epoch_end(self, outputs: List[Any]) -> None: + """Called when the epoch ends.""" + for callback in self.callbacks: + callback.on_predict_epoch_end(self, self.lightning_module, outputs) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_epoch_start(self): + """Called when either of train/val/test epoch begins.""" + for callback in self.callbacks: + callback.on_epoch_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_epoch_end(self): + """Called when either of train/val/test epoch ends.""" + for callback in self.callbacks: + callback.on_epoch_end(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_train_start(self): + """Called when the train begins.""" + for callback in self.callbacks: + callback.on_train_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_train_end(self): + """Called when the train ends.""" + for callback in self.callbacks: + callback.on_train_end(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_pretrain_routine_start(self) -> None: + """Called when the pre-train routine begins.""" + for callback in self.callbacks: + callback.on_pretrain_routine_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_pretrain_routine_end(self) -> None: + """Called when the pre-train routine ends.""" + for callback in self.callbacks: + callback.on_pretrain_routine_end(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_batch_start(self): + """Called when the training batch begins.""" + for callback in self.callbacks: + callback.on_batch_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_batch_end(self): + """Called when the training batch ends.""" + for callback in self.callbacks: + callback.on_batch_end(self, self.lightning_module) + # TODO: Update this in v1.7 (deprecation: #9816) + # TODO: Delete this in v1.8 (deprecation #10979) def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): """Called when the training batch begins.""" for callback in self.callbacks: @@ -41,6 +199,7 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx) # TODO: Update this in v1.7 (deprecation: #9816) + # TODO: Delete this in v1.8 (deprecation #10979) def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0): """Called when the training batch ends.""" for callback in self.callbacks: @@ -49,6 +208,94 @@ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_ else: callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx) + # TODO: Delete this in v1.8 (deprecation #10979) + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): + """Called when the validation batch begins.""" + for callback in self.callbacks: + callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_validation_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx): + """Called when the validation batch ends.""" + for callback in self.callbacks: + callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_test_batch_start(self, batch, batch_idx, dataloader_idx): + """Called when the test batch begins.""" + for callback in self.callbacks: + callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_test_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx): + """Called when the test batch ends.""" + for callback in self.callbacks: + callback.on_test_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """Called when the predict batch begins.""" + for callback in self.callbacks: + callback.on_predict_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_predict_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """Called when the predict batch ends.""" + for callback in self.callbacks: + callback.on_predict_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_validation_start(self): + """Called when the validation loop begins.""" + for callback in self.callbacks: + callback.on_validation_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_validation_end(self): + """Called when the validation loop ends.""" + for callback in self.callbacks: + callback.on_validation_end(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_test_start(self): + """Called when the test begins.""" + for callback in self.callbacks: + callback.on_test_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_test_end(self): + """Called when the test ends.""" + for callback in self.callbacks: + callback.on_test_end(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_predict_start(self) -> None: + """Called when predict begins.""" + for callback in self.callbacks: + callback.on_predict_start(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_predict_end(self) -> None: + """Called when predict ends.""" + for callback in self.callbacks: + callback.on_predict_end(self, self.lightning_module) + + def on_keyboard_interrupt(self): + r""" + .. deprecated:: v1.5 + This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7. + + Called when any trainer execution is interrupted by KeyboardInterrupt. + """ + for callback in self.callbacks: + callback.on_keyboard_interrupt(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_exception(self, exception: BaseException) -> None: + """Called when any trainer execution is interrupted by an exception.""" + for callback in self.callbacks: + callback.on_exception(self, self.lightning_module, exception) + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: """Called when saving a model checkpoint.""" callback_states = {} @@ -83,3 +330,27 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: if state: state = deepcopy(state) callback.on_load_checkpoint(self, self.lightning_module, state) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_before_backward(self, loss: torch.Tensor) -> None: + """Called before ``loss.backward()``.""" + for callback in self.callbacks: + callback.on_before_backward(self, self.lightning_module, loss) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_after_backward(self): + """Called after loss.backward() and before optimizers do anything.""" + for callback in self.callbacks: + callback.on_after_backward(self, self.lightning_module) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_before_optimizer_step(self, optimizer, optimizer_idx): + """Called after on_after_backward() once the gradient is accumulated and before optimizer.step().""" + for callback in self.callbacks: + callback.on_before_optimizer_step(self, self.lightning_module, optimizer, optimizer_idx) + + # TODO: Delete this in v1.8 (deprecation #10979) + def on_before_zero_grad(self, optimizer): + """Called after optimizer.step() and before optimizer.zero_grad().""" + for callback in self.callbacks: + callback.on_before_zero_grad(self, self.lightning_module, optimizer) From e17308541d9059948f8ea88705d55c6a9edbb646 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Sat, 18 Dec 2021 17:08:25 -0800 Subject: [PATCH 2/7] add dep warnings --- pytorch_lightning/trainer/callback_hook.py | 495 +++++++++++++++++---- tests/deprecated_api/test_remove_1-7.py | 24 + 2 files changed, 428 insertions(+), 91 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 1cca7b9f5fcd5..9130ec4500c52 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -32,166 +32,353 @@ class TrainerCallbackHookMixin(ABC): callbacks: List[Callback] = [] lightning_module: "pl.LightningModule" - # TODO: Delete this in v1.8 (deprecation #10979) def on_before_accelerator_backend_setup(self) -> None: - """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_before_accelerator_backend_setup` was deprecated in v1.6 + and will be removed in v1.7. + + Called at the beginning of fit (train + validate), validate, test, or predict, or tune. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_before_accelerator_backend_setup` was deprecated in v1.6 " + "and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_configure_sharded_model(self) -> None: - """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_configure_sharded_model` was deprecated in v1.6 and will be removed in v1.7. + + Called at the beginning of fit (train + validate), validate, test, or predict, or tune. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_configure_sharded_model` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_configure_sharded_model(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def setup(self, stage: Optional[str]) -> None: - """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.setup` was deprecated in v1.6 and will be removed in v1.7. + + Called at the beginning of fit (train + validate), validate, test, or predict, or tune. + """ + rank_zero_deprecation("`TrainerCallbackHookMixin.setup` was deprecated in v1.6 and will be removed in v1.7.") for callback in self.callbacks: callback.setup(self, self.lightning_module, stage=stage) - # TODO: Delete this in v1.8 (deprecation #10979) def teardown(self, stage: Optional[str] = None) -> None: - """Called at the end of fit (train + validate), validate, test, or predict, or tune.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.teardown` was deprecated in v1.6 and will be removed in v1.7. + + Called at the end of fit (train + validate), validate, test, or predict, or tune. + """ + rank_zero_deprecation("`TrainerCallbackHookMixin.teardown` was deprecated in v1.6 and will be removed in v1.7.") for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage=stage) - # TODO: Delete this in v1.8 (deprecation #10979) def on_init_start(self): - """Called when the trainer initialization begins, model has not yet been set.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_init_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the trainer initialization begins, model has not yet been set. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_init_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_init_start(self) - # TODO: Delete this in v1.8 (deprecation #10979) def on_init_end(self): - """Called when the trainer initialization ends, model has not yet been set.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_init_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the trainer initialization ends, model has not yet been set. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_init_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_init_end(self) - # TODO: Delete this in v1.8 (deprecation #10979) def on_fit_start(self): - """Called when the trainer initialization begins, model has not yet been set.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_fit_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the trainer initialization begins, model has not yet been set. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_fit_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_fit_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_fit_end(self): - """Called when the trainer initialization begins, model has not yet been set.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_fit_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the trainer initialization begins, model has not yet been set. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_fit_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_fit_end(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_sanity_check_start(self): - """Called when the validation sanity check starts.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_sanity_check_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the validation sanity check starts. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_sanity_check_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_sanity_check_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_sanity_check_end(self): - """Called when the validation sanity check ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_sanity_check_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the validation sanity check ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_sanity_check_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_sanity_check_end(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_train_epoch_start(self): - """Called when the epoch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_epoch_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the epoch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_epoch_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_train_epoch_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_train_epoch_end(self): - """Called when the epoch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_epoch_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the epoch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_epoch_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_validation_epoch_start(self): - """Called when the epoch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_epoch_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the epoch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_epoch_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_validation_epoch_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_validation_epoch_end(self): - """Called when the validation epoch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_epoch_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the validation epoch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_epoch_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_validation_epoch_end(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_test_epoch_start(self): - """Called when the epoch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_epoch_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the epoch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_epoch_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_test_epoch_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_test_epoch_end(self): - """Called when the test epoch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_epoch_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the test epoch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_epoch_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_test_epoch_end(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_predict_epoch_start(self) -> None: - """Called when the epoch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_epoch_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the epoch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_epoch_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_predict_epoch_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_predict_epoch_end(self, outputs: List[Any]) -> None: - """Called when the epoch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_epoch_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the epoch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_epoch_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_predict_epoch_end(self, self.lightning_module, outputs) - # TODO: Delete this in v1.8 (deprecation #10979) def on_epoch_start(self): - """Called when either of train/val/test epoch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_epoch_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when either of train/val/test epoch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_epoch_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_epoch_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_epoch_end(self): - """Called when either of train/val/test epoch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_epoch_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when either of train/val/test epoch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_epoch_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_epoch_end(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_train_start(self): - """Called when the train begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the train begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_train_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_train_end(self): - """Called when the train ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the train ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_train_end(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_pretrain_routine_start(self) -> None: - """Called when the pre-train routine begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_pretrain_routine_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the pre-train routine begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_pretrain_routine_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_pretrain_routine_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_pretrain_routine_end(self) -> None: - """Called when the pre-train routine ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_pretrain_routine_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the pre-train routine ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_pretrain_routine_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_pretrain_routine_end(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_batch_start(self): - """Called when the training batch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_batch_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the training batch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_batch_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_batch_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_batch_end(self): - """Called when the training batch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_batch_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the training batch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_batch_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_batch_end(self, self.lightning_module) # TODO: Update this in v1.7 (deprecation: #9816) - # TODO: Delete this in v1.8 (deprecation #10979) def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): - """Called when the training batch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the training batch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True): callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, 0) @@ -199,84 +386,175 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx) # TODO: Update this in v1.7 (deprecation: #9816) - # TODO: Delete this in v1.8 (deprecation #10979) def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0): - """Called when the training batch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the training batch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True): callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0) else: callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx) - # TODO: Delete this in v1.8 (deprecation #10979) def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): - """Called when the validation batch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_batch_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the validation batch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_batch_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) - # TODO: Delete this in v1.8 (deprecation #10979) def on_validation_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx): - """Called when the validation batch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_batch_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the validation batch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_batch_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) - # TODO: Delete this in v1.8 (deprecation #10979) def on_test_batch_start(self, batch, batch_idx, dataloader_idx): - """Called when the test batch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the test batch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) - # TODO: Delete this in v1.8 (deprecation #10979) def on_test_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx): - """Called when the test batch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_batch_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the test batch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_batch_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_test_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) - # TODO: Delete this in v1.8 (deprecation #10979) def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """Called when the predict batch begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_batch_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the predict batch begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_batch_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_predict_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) - # TODO: Delete this in v1.8 (deprecation #10979) def on_predict_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """Called when the predict batch ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the predict batch ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_predict_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) - # TODO: Delete this in v1.8 (deprecation #10979) def on_validation_start(self): - """Called when the validation loop begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the validation loop begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_validation_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_validation_end(self): - """Called when the validation loop ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_validation_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the validation loop ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_validation_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_validation_end(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_test_start(self): - """Called when the test begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when the test begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_test_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_test_end(self): - """Called when the test ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_test_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when the test ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_test_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_test_end(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_predict_start(self) -> None: - """Called when predict begins.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_start` was deprecated in v1.6 and will be removed in v1.7. + + Called when predict begins. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_start` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_predict_start(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_predict_end(self) -> None: - """Called when predict ends.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_predict_end` was deprecated in v1.6 and will be removed in v1.7. + + Called when predict ends. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_predict_end` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_predict_end(self, self.lightning_module) @@ -290,9 +568,16 @@ def on_keyboard_interrupt(self): for callback in self.callbacks: callback.on_keyboard_interrupt(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_exception(self, exception: BaseException) -> None: - """Called when any trainer execution is interrupted by an exception.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_exception` was deprecated in v1.6 and will be removed in v1.7. + + Called when any trainer execution is interrupted by an exception. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_exception` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_exception(self, self.lightning_module, exception) @@ -331,26 +616,54 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: state = deepcopy(state) callback.on_load_checkpoint(self, self.lightning_module, state) - # TODO: Delete this in v1.8 (deprecation #10979) def on_before_backward(self, loss: torch.Tensor) -> None: - """Called before ``loss.backward()``.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_before_backward` was deprecated in v1.6 and will be removed in v1.7. + + Called before ``loss.backward()``. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_before_backward` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_before_backward(self, self.lightning_module, loss) - # TODO: Delete this in v1.8 (deprecation #10979) def on_after_backward(self): - """Called after loss.backward() and before optimizers do anything.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.7. + + Called after loss.backward() and before optimizers do anything. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_after_backward(self, self.lightning_module) - # TODO: Delete this in v1.8 (deprecation #10979) def on_before_optimizer_step(self, optimizer, optimizer_idx): - """Called after on_after_backward() once the gradient is accumulated and before optimizer.step().""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.7. + + Called after on_after_backward() once the gradient is accumulated and before optimizer.step(). + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_before_optimizer_step(self, self.lightning_module, optimizer, optimizer_idx) - # TODO: Delete this in v1.8 (deprecation #10979) def on_before_zero_grad(self, optimizer): - """Called after optimizer.step() and before optimizer.zero_grad().""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_before_zero_grad` was deprecated in v1.6 and will be removed in v1.7. + + Called after optimizer.step() and before optimizer.zero_grad(). + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_before_zero_grad` was deprecated in v1.6 and will be removed in v1.7." + ) for callback in self.callbacks: callback.on_before_zero_grad(self, self.lightning_module, optimizer) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 986491a306ce7..5229858dfa8f3 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -551,3 +551,27 @@ def post_dispatch(self, trainer): with pytest.deprecated_call(match=escape("`CustomPlugin.post_dispatch()` has been deprecated in v1.6")): CustomPlugin(torch.device("cpu")) + + +def test_v1_7_0_deprecate_TrainerCallbackHookMixin(): + trainer = Trainer( + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + enable_progress_bar=False, + logger=False, + ) + with pytest.deprecated_call( + match="`TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.7" + ): + trainer.call_hook("on_after_backward") + + with pytest.deprecated_call( + match="`TrainerCallbackHookMixin.on_exception` was deprecated in v1.6 and will be removed in v1.7" + ): + trainer.call_hook("on_exception", Exception) + + with pytest.deprecated_call( + match="`TrainerCallbackHookMixin.on_predict_start` was deprecated in v1.6 and will be removed in v1.7" + ): + trainer.call_hook("on_predict_start") From b1f884fed40ec2e2673dd10719bbc85bfa355893 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Sat, 18 Dec 2021 17:11:54 -0800 Subject: [PATCH 3/7] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9990bff5f2153..f8fc4eabd52f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068)) + +- Deprecated most of the callback_hook functions in `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148)) + ### Removed - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507)) From dde1069f53e21202f1a187ed4bdf86c6bca70ca0 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Mon, 20 Dec 2021 10:05:55 -0800 Subject: [PATCH 4/7] addr comments --- CHANGELOG.md | 2 +- pytorch_lightning/trainer/callback_hook.py | 24 +++++- tests/deprecated_api/test_remove_1-7.py | 97 ++++++++++++++++++---- 3 files changed, 105 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 15e5c4abe26a5..924367fde8e1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -179,7 +179,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068)) -- Deprecated most of the callback_hook functions in `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148)) +- Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148)) ### Removed diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 9130ec4500c52..8df30549059c5 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -26,6 +26,10 @@ class TrainerCallbackHookMixin(ABC): + r""" + .. deprecated:: v1.6 + The `TrainerCallbackHookMixin` class was deprecated in v1.6 and will be removed in v1.7. + """ # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class @@ -582,7 +586,15 @@ def on_exception(self, exception: BaseException) -> None: callback.on_exception(self, self.lightning_module, exception) def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: - """Called when saving a model checkpoint.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.7. + + Called when saving a model checkpoint. + """ + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.7." + ) callback_states = {} for callback in self.callbacks: state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) @@ -591,10 +603,18 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: return callback_states def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """Called when loading a model checkpoint.""" + r""" + .. deprecated:: v1.6 + `TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.7. + + Called when loading a model checkpoint. + """ # Todo: the `callback_states` are dropped with TPUSpawn as they # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 + rank_zero_deprecation( + "`TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.7." + ) callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") if callback_states is None: diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 5229858dfa8f3..1b1f0a401819c 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -553,7 +553,57 @@ def post_dispatch(self, trainer): CustomPlugin(torch.device("cpu")) -def test_v1_7_0_deprecate_TrainerCallbackHookMixin(): +def test_v1_7_0_deprecate_trainer_callback_hook_mixin(): + methods_with_self = [ + "on_before_accelerator_backend_setup", + "on_configure_sharded_model", + "on_init_start", + "on_init_end", + "on_fit_start", + "on_fit_end", + "on_sanity_check_start", + "on_sanity_check_end", + "on_train_epoch_start", + "on_train_epoch_end", + "on_validation_epoch_start", + "on_validation_epoch_end", + "on_test_epoch_start", + "on_test_epoch_end", + "on_predict_epoch_start", + "on_epoch_start", + "on_epoch_end", + "on_train_start", + "on_train_end", + #"on_pretrain_routine_start", + "on_pretrain_routine_end", + "on_batch_start", + "on_batch_end", + "on_validation_start", + "on_validation_end", + "on_test_start", + "on_test_end", + "on_predict_start", + "on_predict_end", + "on_keyboard_interrupt", + "on_after_backward", + ] + methods_with_stage = [ + "setup", + "teardown", + ] + methods_with_batch_batch_idx_dataloader_idx = [ + "on_train_batch_start", + "on_validation_batch_start", + "on_test_batch_start", + "on_predict_batch_start", + ] + methods_with_outputs_batch_batch_idx_dataloader_idx = [ + "on_train_batch_end", + "on_validation_batch_end", + "on_test_batch_end", + "on_predict_batch_end", + ] + methods_with_checkpoint = ["on_save_checkpoint", "on_load_checkpoint"] trainer = Trainer( max_epochs=1, limit_val_batches=0.1, @@ -561,17 +611,34 @@ def test_v1_7_0_deprecate_TrainerCallbackHookMixin(): enable_progress_bar=False, logger=False, ) - with pytest.deprecated_call( - match="`TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.7" - ): - trainer.call_hook("on_after_backward") - - with pytest.deprecated_call( - match="`TrainerCallbackHookMixin.on_exception` was deprecated in v1.6 and will be removed in v1.7" - ): - trainer.call_hook("on_exception", Exception) - - with pytest.deprecated_call( - match="`TrainerCallbackHookMixin.on_predict_start` was deprecated in v1.6 and will be removed in v1.7" - ): - trainer.call_hook("on_predict_start") + model = BoringModel() + for method_name in methods_with_self: + fn = getattr(trainer, method_name) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): + fn() + # for method_name in methods_with_stage: + # fn = getattr(trainer, method_name) + # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): + # fn(stage="test") + # for method_name in methods_with_batch_batch_idx_dataloader_idx: + # fn = getattr(trainer, method_name) + # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): + # fn(batch, batch_idx=0, dataloader_idx=0) + # for method_name in methods_with_outputs_batch_batch_idx_dataloader_idx: + # fn = getattr(trainer, method_name) + # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): + # fn(outputs, batch, batch_idx=0, dataloader_idx=0) + # for method_name in methods_with_outputs_batch_batch_idx_dataloader_idx: + # fn = getattr(trainer, method_name) + # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): + # fn(checkpoint) + # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): + # trainer.on_predict_epoch_end(outputs) + # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): + # trainer.on_exception(Exception) + # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): + # trainer.on_before_backward(loss=torch.tensor([[1.0, -1.0], [1.0, -1.0]])) + # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): + # trainer.on_before_optimizer_step(optimizer, optimizer_idx) + # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): + # trainer.on_before_zero_grad(optimizer) From 59b7e53f145575d0d1e42c0133d6736d687194df Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Dec 2021 20:50:15 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-7.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 1b1f0a401819c..23d20d8470e26 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -574,7 +574,7 @@ def test_v1_7_0_deprecate_trainer_callback_hook_mixin(): "on_epoch_end", "on_train_start", "on_train_end", - #"on_pretrain_routine_start", + # "on_pretrain_routine_start", "on_pretrain_routine_end", "on_batch_start", "on_batch_end", From c9c8d1424bf4313f5f5425472e531905e7e891fb Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Mon, 20 Dec 2021 12:51:31 -0800 Subject: [PATCH 6/7] addr comments --- pytorch_lightning/trainer/callback_hook.py | 190 ++++++++++----------- pytorch_lightning/trainer/trainer.py | 33 +++- tests/deprecated_api/test_remove_1-7.py | 91 ---------- tests/deprecated_api/test_remove_1-8.py | 95 +++++++++++ 4 files changed, 218 insertions(+), 191 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8df30549059c5..a1002bfd55621 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -28,7 +28,7 @@ class TrainerCallbackHookMixin(ABC): r""" .. deprecated:: v1.6 - The `TrainerCallbackHookMixin` class was deprecated in v1.6 and will be removed in v1.7. + The `TrainerCallbackHookMixin` class was deprecated in v1.6 and will be removed in v1.8. """ # this is just a summary on variables used in this abstract class, @@ -40,13 +40,13 @@ def on_before_accelerator_backend_setup(self) -> None: r""" .. deprecated:: v1.6 `TrainerCallbackHookMixin.on_before_accelerator_backend_setup` was deprecated in v1.6 - and will be removed in v1.7. + and will be removed in v1.8. Called at the beginning of fit (train + validate), validate, test, or predict, or tune. """ rank_zero_deprecation( "`TrainerCallbackHookMixin.on_before_accelerator_backend_setup` was deprecated in v1.6 " - "and will be removed in v1.7." + "and will be removed in v1.8." ) for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, self.lightning_module) @@ -54,12 +54,12 @@ def on_before_accelerator_backend_setup(self) -> None: def on_configure_sharded_model(self) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_configure_sharded_model` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_configure_sharded_model` was deprecated in v1.6 and will be removed in v1.8. Called at the beginning of fit (train + validate), validate, test, or predict, or tune. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_configure_sharded_model` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_configure_sharded_model` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_configure_sharded_model(self, self.lightning_module) @@ -67,34 +67,34 @@ def on_configure_sharded_model(self) -> None: def setup(self, stage: Optional[str]) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.setup` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.setup` was deprecated in v1.6 and will be removed in v1.8. Called at the beginning of fit (train + validate), validate, test, or predict, or tune. """ - rank_zero_deprecation("`TrainerCallbackHookMixin.setup` was deprecated in v1.6 and will be removed in v1.7.") + rank_zero_deprecation("`TrainerCallbackHookMixin.setup` was deprecated in v1.6 and will be removed in v1.8.") for callback in self.callbacks: callback.setup(self, self.lightning_module, stage=stage) def teardown(self, stage: Optional[str] = None) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.teardown` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.teardown` was deprecated in v1.6 and will be removed in v1.8. Called at the end of fit (train + validate), validate, test, or predict, or tune. """ - rank_zero_deprecation("`TrainerCallbackHookMixin.teardown` was deprecated in v1.6 and will be removed in v1.7.") + rank_zero_deprecation("`TrainerCallbackHookMixin.teardown` was deprecated in v1.6 and will be removed in v1.8.") for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage=stage) def on_init_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_init_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_init_start` was deprecated in v1.6 and will be removed in v1.8. Called when the trainer initialization begins, model has not yet been set. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_init_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_init_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_init_start(self) @@ -102,12 +102,12 @@ def on_init_start(self): def on_init_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_init_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_init_end` was deprecated in v1.6 and will be removed in v1.8. Called when the trainer initialization ends, model has not yet been set. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_init_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_init_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_init_end(self) @@ -115,12 +115,12 @@ def on_init_end(self): def on_fit_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_fit_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_fit_start` was deprecated in v1.6 and will be removed in v1.8. Called when the trainer initialization begins, model has not yet been set. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_fit_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_fit_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_fit_start(self, self.lightning_module) @@ -128,12 +128,12 @@ def on_fit_start(self): def on_fit_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_fit_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_fit_end` was deprecated in v1.6 and will be removed in v1.8. Called when the trainer initialization begins, model has not yet been set. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_fit_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_fit_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_fit_end(self, self.lightning_module) @@ -141,12 +141,12 @@ def on_fit_end(self): def on_sanity_check_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_sanity_check_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_sanity_check_start` was deprecated in v1.6 and will be removed in v1.8. Called when the validation sanity check starts. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_sanity_check_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_sanity_check_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_sanity_check_start(self, self.lightning_module) @@ -154,12 +154,12 @@ def on_sanity_check_start(self): def on_sanity_check_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_sanity_check_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_sanity_check_end` was deprecated in v1.6 and will be removed in v1.8. Called when the validation sanity check ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_sanity_check_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_sanity_check_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_sanity_check_end(self, self.lightning_module) @@ -167,12 +167,12 @@ def on_sanity_check_end(self): def on_train_epoch_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_train_epoch_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_train_epoch_start` was deprecated in v1.6 and will be removed in v1.8. Called when the epoch begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_train_epoch_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_train_epoch_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_train_epoch_start(self, self.lightning_module) @@ -180,12 +180,12 @@ def on_train_epoch_start(self): def on_train_epoch_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_train_epoch_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_train_epoch_end` was deprecated in v1.6 and will be removed in v1.8. Called when the epoch ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_train_epoch_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_train_epoch_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module) @@ -193,12 +193,12 @@ def on_train_epoch_end(self): def on_validation_epoch_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_validation_epoch_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_validation_epoch_start` was deprecated in v1.6 and will be removed in v1.8. Called when the epoch begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_validation_epoch_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_validation_epoch_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_validation_epoch_start(self, self.lightning_module) @@ -206,12 +206,12 @@ def on_validation_epoch_start(self): def on_validation_epoch_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_validation_epoch_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_validation_epoch_end` was deprecated in v1.6 and will be removed in v1.8. Called when the validation epoch ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_validation_epoch_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_validation_epoch_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_validation_epoch_end(self, self.lightning_module) @@ -219,12 +219,12 @@ def on_validation_epoch_end(self): def on_test_epoch_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_test_epoch_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_test_epoch_start` was deprecated in v1.6 and will be removed in v1.8. Called when the epoch begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_test_epoch_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_test_epoch_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_test_epoch_start(self, self.lightning_module) @@ -232,12 +232,12 @@ def on_test_epoch_start(self): def on_test_epoch_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_test_epoch_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_test_epoch_end` was deprecated in v1.6 and will be removed in v1.8. Called when the test epoch ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_test_epoch_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_test_epoch_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_test_epoch_end(self, self.lightning_module) @@ -245,12 +245,12 @@ def on_test_epoch_end(self): def on_predict_epoch_start(self) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_predict_epoch_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_predict_epoch_start` was deprecated in v1.6 and will be removed in v1.8. Called when the epoch begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_predict_epoch_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_predict_epoch_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_predict_epoch_start(self, self.lightning_module) @@ -258,12 +258,12 @@ def on_predict_epoch_start(self) -> None: def on_predict_epoch_end(self, outputs: List[Any]) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_predict_epoch_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_predict_epoch_end` was deprecated in v1.6 and will be removed in v1.8. Called when the epoch ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_predict_epoch_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_predict_epoch_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_predict_epoch_end(self, self.lightning_module, outputs) @@ -271,12 +271,12 @@ def on_predict_epoch_end(self, outputs: List[Any]) -> None: def on_epoch_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_epoch_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_epoch_start` was deprecated in v1.6 and will be removed in v1.8. Called when either of train/val/test epoch begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_epoch_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_epoch_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_epoch_start(self, self.lightning_module) @@ -284,12 +284,12 @@ def on_epoch_start(self): def on_epoch_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_epoch_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_epoch_end` was deprecated in v1.6 and will be removed in v1.8. Called when either of train/val/test epoch ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_epoch_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_epoch_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_epoch_end(self, self.lightning_module) @@ -297,12 +297,12 @@ def on_epoch_end(self): def on_train_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_train_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_train_start` was deprecated in v1.6 and will be removed in v1.8. Called when the train begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_train_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_train_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_train_start(self, self.lightning_module) @@ -310,12 +310,12 @@ def on_train_start(self): def on_train_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_train_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_train_end` was deprecated in v1.6 and will be removed in v1.8. Called when the train ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_train_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_train_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_train_end(self, self.lightning_module) @@ -323,12 +323,12 @@ def on_train_end(self): def on_pretrain_routine_start(self) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_pretrain_routine_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_pretrain_routine_start` was deprecated in v1.6 and will be removed in v1.8. Called when the pre-train routine begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_pretrain_routine_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_pretrain_routine_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_pretrain_routine_start(self, self.lightning_module) @@ -336,12 +336,12 @@ def on_pretrain_routine_start(self) -> None: def on_pretrain_routine_end(self) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_pretrain_routine_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_pretrain_routine_end` was deprecated in v1.6 and will be removed in v1.8. Called when the pre-train routine ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_pretrain_routine_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_pretrain_routine_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_pretrain_routine_end(self, self.lightning_module) @@ -349,12 +349,12 @@ def on_pretrain_routine_end(self) -> None: def on_batch_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_batch_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_batch_start` was deprecated in v1.6 and will be removed in v1.8. Called when the training batch begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_batch_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_batch_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_batch_start(self, self.lightning_module) @@ -362,12 +362,12 @@ def on_batch_start(self): def on_batch_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_batch_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_batch_end` was deprecated in v1.6 and will be removed in v1.8. Called when the training batch ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_batch_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_batch_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_batch_end(self, self.lightning_module) @@ -376,12 +376,12 @@ def on_batch_end(self): def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8. Called when the training batch begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True): @@ -393,12 +393,12 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.8. Called when the training batch ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True): @@ -409,12 +409,12 @@ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_ def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_validation_batch_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_validation_batch_start` was deprecated in v1.6 and will be removed in v1.8. Called when the validation batch begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_validation_batch_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_validation_batch_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) @@ -422,12 +422,12 @@ def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): def on_validation_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_validation_batch_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_validation_batch_end` was deprecated in v1.6 and will be removed in v1.8. Called when the validation batch ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_validation_batch_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_validation_batch_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) @@ -435,12 +435,12 @@ def on_validation_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, datalo def on_test_batch_start(self, batch, batch_idx, dataloader_idx): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.8. Called when the test batch begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) @@ -448,12 +448,12 @@ def on_test_batch_start(self, batch, batch_idx, dataloader_idx): def on_test_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_test_batch_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_test_batch_end` was deprecated in v1.6 and will be removed in v1.8. Called when the test batch ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_test_batch_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_test_batch_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_test_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) @@ -461,12 +461,12 @@ def on_test_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_i def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_predict_batch_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_predict_batch_start` was deprecated in v1.6 and will be removed in v1.8. Called when the predict batch begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_predict_batch_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_predict_batch_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_predict_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) @@ -474,12 +474,12 @@ def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int def on_predict_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.8. Called when the predict batch ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_predict_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) @@ -487,12 +487,12 @@ def on_predict_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, def on_validation_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_validation_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_validation_start` was deprecated in v1.6 and will be removed in v1.8. Called when the validation loop begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_validation_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_validation_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_validation_start(self, self.lightning_module) @@ -500,12 +500,12 @@ def on_validation_start(self): def on_validation_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_validation_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_validation_end` was deprecated in v1.6 and will be removed in v1.8. Called when the validation loop ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_validation_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_validation_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_validation_end(self, self.lightning_module) @@ -513,12 +513,12 @@ def on_validation_end(self): def on_test_start(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_test_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_test_start` was deprecated in v1.6 and will be removed in v1.8. Called when the test begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_test_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_test_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_test_start(self, self.lightning_module) @@ -526,12 +526,12 @@ def on_test_start(self): def on_test_end(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_test_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_test_end` was deprecated in v1.6 and will be removed in v1.8. Called when the test ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_test_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_test_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_test_end(self, self.lightning_module) @@ -539,12 +539,12 @@ def on_test_end(self): def on_predict_start(self) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_predict_start` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_predict_start` was deprecated in v1.6 and will be removed in v1.8. Called when predict begins. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_predict_start` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_predict_start` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_predict_start(self, self.lightning_module) @@ -552,12 +552,12 @@ def on_predict_start(self) -> None: def on_predict_end(self) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_predict_end` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_predict_end` was deprecated in v1.6 and will be removed in v1.8. Called when predict ends. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_predict_end` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_predict_end` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_predict_end(self, self.lightning_module) @@ -575,12 +575,12 @@ def on_keyboard_interrupt(self): def on_exception(self, exception: BaseException) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_exception` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_exception` was deprecated in v1.6 and will be removed in v1.8. Called when any trainer execution is interrupted by an exception. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_exception` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_exception` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_exception(self, self.lightning_module, exception) @@ -588,12 +588,12 @@ def on_exception(self, exception: BaseException) -> None: def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.8. Called when saving a model checkpoint. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.8." ) callback_states = {} for callback in self.callbacks: @@ -605,7 +605,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8. Called when loading a model checkpoint. """ @@ -613,7 +613,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8." ) callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") @@ -639,12 +639,12 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def on_before_backward(self, loss: torch.Tensor) -> None: r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_before_backward` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_before_backward` was deprecated in v1.6 and will be removed in v1.8. Called before ``loss.backward()``. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_before_backward` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_before_backward` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_before_backward(self, self.lightning_module, loss) @@ -652,12 +652,12 @@ def on_before_backward(self, loss: torch.Tensor) -> None: def on_after_backward(self): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.8. Called after loss.backward() and before optimizers do anything. """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_after_backward(self, self.lightning_module) @@ -665,12 +665,12 @@ def on_after_backward(self): def on_before_optimizer_step(self, optimizer, optimizer_idx): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.8. Called after on_after_backward() once the gradient is accumulated and before optimizer.step(). """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_before_optimizer_step(self, self.lightning_module, optimizer, optimizer_idx) @@ -678,12 +678,12 @@ def on_before_optimizer_step(self, optimizer, optimizer_idx): def on_before_zero_grad(self, optimizer): r""" .. deprecated:: v1.6 - `TrainerCallbackHookMixin.on_before_zero_grad` was deprecated in v1.6 and will be removed in v1.7. + `TrainerCallbackHookMixin.on_before_zero_grad` was deprecated in v1.6 and will be removed in v1.8. Called after optimizer.step() and before optimizer.zero_grad(). """ rank_zero_deprecation( - "`TrainerCallbackHookMixin.on_before_zero_grad` was deprecated in v1.6 and will be removed in v1.7." + "`TrainerCallbackHookMixin.on_before_zero_grad` was deprecated in v1.6 and will be removed in v1.8." ) for callback in self.callbacks: callback.on_before_zero_grad(self, self.lightning_module, optimizer) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3740be974b597..c36dd54102c7c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -99,12 +99,14 @@ from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import ( _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT, EVAL_DATALOADERS, LRSchedulerTypeUnion, + STEP_OUTPUT, TRAIN_DATALOADERS, ) from pytorch_lightning.utilities.warnings import PossibleUserWarning @@ -1554,11 +1556,12 @@ def _call_callback_hooks( pl_module._current_fx_name = hook_name # TODO: remove if block in v1.7 - if hook_name in ("on_train_batch_start", "on_train_batch_end"): - fn = getattr(self, hook_name) - if callable(fn): - with self.profiler.profile(hook_name): - fn(*args, **kwargs) + if hook_name == "on_train_batch_start": + with self.profiler.profile(hook_name): + self._on_train_batch_start(*args, **kwargs) + elif hook_name == "on_train_batch_end": + with self.profiler.profile(hook_name): + self._on_train_batch_end(*args, **kwargs) else: for callback in self.callbacks: fn = getattr(callback, hook_name) @@ -1570,6 +1573,26 @@ def _call_callback_hooks( # restore current_fx when nested context pl_module._current_fx_name = prev_fx_name + # _on_train_batch_start and _on_train_batch_end are needed because of two different deprecations affecting + # the original functions in TrainerCallbackHookMixin: #9816 and #11148 + # TODO: Delete this in v1.7 (deprecations: #9816 and #11148) + def _on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): + """Called when the training batch begins.""" + for callback in self.callbacks: + if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True): + callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, 0) + else: + callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx) + + # TODO: Delete this in v1.7 (deprecations: #9816 and #11148) + def _on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0): + """Called when the training batch ends.""" + for callback in self.callbacks: + if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True): + callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0) + else: + callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx) + # TODO: rename to _call_strategy_hook and eventually no longer need this def _call_ttp_hook( self, diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 1b1f0a401819c..986491a306ce7 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -551,94 +551,3 @@ def post_dispatch(self, trainer): with pytest.deprecated_call(match=escape("`CustomPlugin.post_dispatch()` has been deprecated in v1.6")): CustomPlugin(torch.device("cpu")) - - -def test_v1_7_0_deprecate_trainer_callback_hook_mixin(): - methods_with_self = [ - "on_before_accelerator_backend_setup", - "on_configure_sharded_model", - "on_init_start", - "on_init_end", - "on_fit_start", - "on_fit_end", - "on_sanity_check_start", - "on_sanity_check_end", - "on_train_epoch_start", - "on_train_epoch_end", - "on_validation_epoch_start", - "on_validation_epoch_end", - "on_test_epoch_start", - "on_test_epoch_end", - "on_predict_epoch_start", - "on_epoch_start", - "on_epoch_end", - "on_train_start", - "on_train_end", - #"on_pretrain_routine_start", - "on_pretrain_routine_end", - "on_batch_start", - "on_batch_end", - "on_validation_start", - "on_validation_end", - "on_test_start", - "on_test_end", - "on_predict_start", - "on_predict_end", - "on_keyboard_interrupt", - "on_after_backward", - ] - methods_with_stage = [ - "setup", - "teardown", - ] - methods_with_batch_batch_idx_dataloader_idx = [ - "on_train_batch_start", - "on_validation_batch_start", - "on_test_batch_start", - "on_predict_batch_start", - ] - methods_with_outputs_batch_batch_idx_dataloader_idx = [ - "on_train_batch_end", - "on_validation_batch_end", - "on_test_batch_end", - "on_predict_batch_end", - ] - methods_with_checkpoint = ["on_save_checkpoint", "on_load_checkpoint"] - trainer = Trainer( - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - enable_progress_bar=False, - logger=False, - ) - model = BoringModel() - for method_name in methods_with_self: - fn = getattr(trainer, method_name) - with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): - fn() - # for method_name in methods_with_stage: - # fn = getattr(trainer, method_name) - # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): - # fn(stage="test") - # for method_name in methods_with_batch_batch_idx_dataloader_idx: - # fn = getattr(trainer, method_name) - # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): - # fn(batch, batch_idx=0, dataloader_idx=0) - # for method_name in methods_with_outputs_batch_batch_idx_dataloader_idx: - # fn = getattr(trainer, method_name) - # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): - # fn(outputs, batch, batch_idx=0, dataloader_idx=0) - # for method_name in methods_with_outputs_batch_batch_idx_dataloader_idx: - # fn = getattr(trainer, method_name) - # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): - # fn(checkpoint) - # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): - # trainer.on_predict_epoch_end(outputs) - # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): - # trainer.on_exception(Exception) - # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): - # trainer.on_before_backward(loss=torch.tensor([[1.0, -1.0], [1.0, -1.0]])) - # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): - # trainer.on_before_optimizer_step(optimizer, optimizer_idx) - # with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.7"): - # trainer.on_before_zero_grad(optimizer) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 67ecb736ed6c4..b87a43b8f4939 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -16,6 +16,7 @@ import pytest import torch +from torch import optim from pytorch_lightning import Callback, Trainer from pytorch_lightning.utilities import rank_zero_warn @@ -137,3 +138,97 @@ def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir): match=r"`Trainer.should_rank_save_checkpoint` is deprecated in v1.6 and will be removed in 1.8." ): _ = trainer.should_rank_save_checkpoint + + +def test_v1_8_0_deprecate_trainer_callback_hook_mixin(): + methods_with_self = [ + "on_before_accelerator_backend_setup", + "on_configure_sharded_model", + "on_init_start", + "on_init_end", + "on_fit_start", + "on_fit_end", + "on_sanity_check_start", + "on_sanity_check_end", + "on_train_epoch_start", + "on_train_epoch_end", + "on_validation_epoch_start", + "on_validation_epoch_end", + "on_test_epoch_start", + "on_test_epoch_end", + "on_predict_epoch_start", + "on_epoch_start", + "on_epoch_end", + "on_train_start", + "on_train_end", + "on_pretrain_routine_start", + "on_pretrain_routine_end", + "on_batch_start", + "on_batch_end", + "on_validation_start", + "on_validation_end", + "on_test_start", + "on_test_end", + "on_predict_start", + "on_predict_end", + "on_after_backward", + ] + methods_with_stage = [ + "setup", + "teardown", + ] + methods_with_batch_batch_idx_dataloader_idx = [ + "on_train_batch_start", + "on_validation_batch_start", + "on_test_batch_start", + "on_predict_batch_start", + ] + methods_with_outputs_batch_batch_idx_dataloader_idx = [ + "on_train_batch_end", + "on_validation_batch_end", + "on_test_batch_end", + "on_predict_batch_end", + ] + methods_with_checkpoint = ["on_save_checkpoint", "on_load_checkpoint"] + trainer = Trainer( + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + enable_progress_bar=False, + logger=False, + ) + model = BoringModel() + # need to attach model to trainer for testing of `on_pretrain_routine_start` + trainer.fit(model) + for method_name in methods_with_self: + fn = getattr(trainer, method_name, None) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + fn() + for method_name in methods_with_stage: + fn = getattr(trainer, method_name) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + fn(stage="test") + for method_name in methods_with_batch_batch_idx_dataloader_idx: + fn = getattr(trainer, method_name) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + fn(batch={}, batch_idx=0, dataloader_idx=0) + for method_name in methods_with_outputs_batch_batch_idx_dataloader_idx: + fn = getattr(trainer, method_name) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + fn(outputs=torch.tensor([[1.0, -1.0], [1.0, -1.0]]), batch={}, batch_idx=0, dataloader_idx=0) + for method_name in methods_with_checkpoint: + fn = getattr(trainer, method_name) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + fn(checkpoint={}) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + trainer.on_predict_epoch_end(outputs=torch.tensor([[1.0, -1.0], [1.0, -1.0]])) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + trainer.on_exception(exception=Exception) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + trainer.on_before_backward(loss=torch.tensor([[1.0, -1.0], [1.0, -1.0]])) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + trainer.on_before_optimizer_step( + optimizer=optim.SGD(model.parameters(), lr=0.01, momentum=0.9), optimizer_idx=0 + ) + with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): + trainer.on_before_zero_grad(optimizer=optim.SGD(model.parameters(), lr=0.01, momentum=0.9)) From 9472aa7e54fbbbcdf79a3af8926395da5f268fad Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Tue, 21 Dec 2021 08:52:37 -0800 Subject: [PATCH 7/7] fix docstring --- pytorch_lightning/trainer/trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fac1873331e57..6b6b3abed91c2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1571,11 +1571,11 @@ def _call_callback_hooks( # restore current_fx when nested context pl_module._current_fx_name = prev_fx_name - # _on_train_batch_start and _on_train_batch_end are needed because of two different deprecations affecting - # the original functions in TrainerCallbackHookMixin: #9816 and #11148 # TODO: Delete this in v1.7 (deprecations: #9816 and #11148) def _on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): - """Called when the training batch begins.""" + r"""Called when the training batch begins. This function is needed because of two different deprecations affecting + the original function in TrainerCallbackHookMixin: #9816 and #11148. + """ for callback in self.callbacks: if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True): callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, 0) @@ -1584,7 +1584,9 @@ def _on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): # TODO: Delete this in v1.7 (deprecations: #9816 and #11148) def _on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0): - """Called when the training batch ends.""" + r"""Called when the training batch ends. This function is needed because of two different deprecations affecting + the original function in TrainerCallbackHookMixin: #9816 and #11148. + """ for callback in self.callbacks: if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True): callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0)