From 906c067b07a7caa148615880792e65bb779ee3c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 27 May 2021 12:27:26 +0200 Subject: [PATCH] Update hooks pseudocode (#7713) --- CHANGELOG.md | 3 +- docs/source/common/lightning_module.rst | 69 ++++++++++++++----------- tests/models/test_hooks.py | 22 ++++++++ 3 files changed, 64 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 69cd38ea5ece9..58cce920af23e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -74,7 +74,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506)) * Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507)) * Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526)) - + * Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) + * Refactored "should run validation" logic when the trainer is signaled to stop ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 295d231ca5ac3..ad8652f4460bb 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1064,7 +1064,9 @@ override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`: Hooks ^^^^^ -This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``. +This is the pseudocode to describe the structure of :meth:`~pytorch_lightning.trainer.Trainer.fit`. +The inputs and outputs of each function are not represented for simplicity. Please check each function's API reference +for more information. .. code-block:: python @@ -1075,36 +1077,41 @@ This is the pseudocode to describe how all the hooks are called during a call to configure_callbacks() - on_fit_start() - - for gpu/tpu in gpu/tpus: - train_on_device(model.copy()) - - on_fit_end() + with parallel(devices): + # devices can be GPUs, TPUs, ... + train_on_device(model) def train_on_device(model): - # setup is called PER DEVICE - setup() + # called PER DEVICE + on_fit_start() + setup('fit') configure_optimizers() + on_pretrain_routine_start() + on_pretrain_routine_end() + + # the sanity check runs here + on_train_start() for epoch in epochs: train_loop() + on_train_end() - teardown() + on_fit_end() + teardown('fit') def train_loop(): on_epoch_start() on_train_epoch_start() - train_outs = [] - for train_batch in train_dataloader(): + + for batch in train_dataloader(): on_train_batch_start() - # ----- train_step methods ------- - out = training_step(batch) - train_outs.append(out) + on_before_batch_transfer() + transfer_batch_to_device() + on_after_batch_transfer() - loss = out.loss + training_step() on_before_zero_grad() optimizer_zero_grad() @@ -1114,38 +1121,42 @@ This is the pseudocode to describe how all the hooks are called during a call to optimizer_step() - on_train_batch_end(out) + on_train_batch_end() if should_check_val: val_loop() - # end training epoch - training_epoch_end(outs) - on_train_epoch_end(outs) + training_epoch_end() + + on_train_epoch_end() on_epoch_end() def val_loop(): - model.eval() + on_validation_model_eval() # calls `model.eval()` torch.set_grad_enabled(False) + on_validation_start() on_epoch_start() on_validation_epoch_start() - val_outs = [] - for val_batch in val_dataloader(): + + for batch in val_dataloader(): on_validation_batch_start() - # -------- val step methods ------- - out = validation_step(val_batch) - val_outs.append(out) + on_before_batch_transfer() + transfer_batch_to_device() + on_after_batch_transfer() + + validation_step() - on_validation_batch_end(out) + on_validation_batch_end() + validation_epoch_end() - validation_epoch_end(val_outs) on_validation_epoch_end() on_epoch_end() + on_validation_end() # set up for train - model.train() + on_validation_model_train() # calls `model.train()` torch.set_grad_enabled(True) backward diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 913f403a14dd3..60354c987fab3 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -255,6 +255,18 @@ def __init__(self): 'on_validation_batch_end', ] + def prepare_data(self): + self.called.append("prepare_data") + return super().prepare_data() + + def configure_callbacks(self): + self.called.append("configure_callbacks") + return super().configure_callbacks() + + def configure_optimizers(self): + self.called.append("configure_optimizers") + return super().configure_optimizers() + def training_step(self, *args, **kwargs): self.called.append("training_step") return super().training_step(*args, **kwargs) @@ -451,7 +463,10 @@ def test_trainer_model_hook_system_fit(tmpdir): assert model.called == [] trainer.fit(model) expected = [ + 'prepare_data', + 'configure_callbacks', 'setup_fit', + 'configure_optimizers', 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -504,7 +519,10 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): assert model.called == [] trainer.fit(model) expected = [ + 'prepare_data', + 'configure_callbacks', 'setup_fit', + 'configure_optimizers', 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -535,6 +553,8 @@ def test_trainer_model_hook_system_validate(tmpdir): assert model.called == [] trainer.validate(model, verbose=False) expected = [ + 'prepare_data', + 'configure_callbacks', 'setup_validate', 'on_validation_model_eval', 'on_validation_start', @@ -567,6 +587,8 @@ def test_trainer_model_hook_system_test(tmpdir): assert model.called == [] trainer.test(model, verbose=False) expected = [ + 'prepare_data', + 'configure_callbacks', 'setup_test', 'on_test_model_eval', 'on_test_start',