Update hooks pseudocode (#7713)
carmocca authored May 27, 2021
1 parent 04dcb17 commit 906c067
Showing 3 changed files with 64 additions and 30 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -74,7 +74,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
* Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))
* Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526))

+* Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
+* Refactored "should run validation" logic when the trainer is signaled to stop ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))

- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))

69 changes: 40 additions & 29 deletions docs/source/common/lightning_module.rst
@@ -1064,7 +1064,9 @@ override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:

Hooks
^^^^^
-This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``.
+This is the pseudocode to describe the structure of :meth:`~pytorch_lightning.trainer.Trainer.fit`.
+The inputs and outputs of each function are not represented for simplicity. Please check each function's API reference
+for more information.

.. code-block:: python
@@ -1075,36 +1077,41 @@

    configure_callbacks()
-    on_fit_start()
-    for gpu/tpu in gpu/tpus:
-        train_on_device(model.copy())
-    on_fit_end()
+    with parallel(devices):
+        # devices can be GPUs, TPUs, ...
+        train_on_device(model)

def train_on_device(model):
-    # setup is called PER DEVICE
-    setup()
+    # called PER DEVICE
+    on_fit_start()
+    setup('fit')
    configure_optimizers()
    on_pretrain_routine_start()
+    on_pretrain_routine_end()
+    # the sanity check runs here
+    on_train_start()
    for epoch in epochs:
        train_loop()
+    on_train_end()
-    teardown()
+    on_fit_end()
+    teardown('fit')

def train_loop():
    on_epoch_start()
    on_train_epoch_start()
-    train_outs = []
-    for train_batch in train_dataloader():
+    for batch in train_dataloader():
        on_train_batch_start()
-        # ----- train_step methods -------
-        out = training_step(batch)
-        train_outs.append(out)
+        on_before_batch_transfer()
+        transfer_batch_to_device()
+        on_after_batch_transfer()
-        loss = out.loss
+        training_step()
        on_before_zero_grad()
        optimizer_zero_grad()

@@ -1114,38 +1121,42 @@

        optimizer_step()
-        on_train_batch_end(out)
+        on_train_batch_end()
        if should_check_val:
            val_loop()
    # end training epoch
-    training_epoch_end(outs)
-    on_train_epoch_end(outs)
+    training_epoch_end()
+    on_train_epoch_end()
    on_epoch_end()

def val_loop():
-    model.eval()
+    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False)
+    on_validation_start()
+    on_epoch_start()
    on_validation_epoch_start()
-    val_outs = []
-    for val_batch in val_dataloader():
+    for batch in val_dataloader():
        on_validation_batch_start()
-        # -------- val step methods -------
-        out = validation_step(val_batch)
-        val_outs.append(out)
+        on_before_batch_transfer()
+        transfer_batch_to_device()
+        on_after_batch_transfer()
+        validation_step()
-        on_validation_batch_end(out)
+        on_validation_batch_end()
-    validation_epoch_end(val_outs)
+    validation_epoch_end()
    on_validation_epoch_end()
+    on_epoch_end()
+    on_validation_end()
    # set up for train
-    model.train()
+    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True)
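The updated pseudocode now lists the batch-transfer hooks (``on_before_batch_transfer``, ``transfer_batch_to_device``, ``on_after_batch_transfer``) inside the training and validation loops. Below is a minimal sketch, not part of this commit, of how a ``LightningModule`` might override two of them; the ``(batch, dataloader_idx)`` signatures are assumed from the 1.3-era API, and the module name and noise augmentation are purely illustrative:

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule


    class AugmentingModel(LightningModule):
        def on_before_batch_transfer(self, batch, dataloader_idx):
            # runs before the batch is moved to the training device (typically still on CPU)
            return batch

        def on_after_batch_transfer(self, batch, dataloader_idx):
            # runs once the batch is on the training device,
            # e.g. cheap on-device noise augmentation
            return batch + 0.01 * torch.randn_like(batch)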
22 changes: 22 additions & 0 deletions tests/models/test_hooks.py
@@ -255,6 +255,18 @@ def __init__(self):
            'on_validation_batch_end',
        ]

+    def prepare_data(self):
+        self.called.append("prepare_data")
+        return super().prepare_data()
+
+    def configure_callbacks(self):
+        self.called.append("configure_callbacks")
+        return super().configure_callbacks()
+
+    def configure_optimizers(self):
+        self.called.append("configure_optimizers")
+        return super().configure_optimizers()
+
    def training_step(self, *args, **kwargs):
        self.called.append("training_step")
        return super().training_step(*args, **kwargs)
@@ -451,7 +463,10 @@ def test_trainer_model_hook_system_fit(tmpdir):
    assert model.called == []
    trainer.fit(model)
    expected = [
+        'prepare_data',
+        'configure_callbacks',
        'setup_fit',
+        'configure_optimizers',
        'on_fit_start',
        'on_pretrain_routine_start',
        'on_pretrain_routine_end',
@@ -504,7 +519,10 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
    assert model.called == []
    trainer.fit(model)
    expected = [
+        'prepare_data',
+        'configure_callbacks',
        'setup_fit',
+        'configure_optimizers',
        'on_fit_start',
        'on_pretrain_routine_start',
        'on_pretrain_routine_end',
@@ -535,6 +553,8 @@ def test_trainer_model_hook_system_validate(tmpdir):
    assert model.called == []
    trainer.validate(model, verbose=False)
    expected = [
+        'prepare_data',
+        'configure_callbacks',
        'setup_validate',
        'on_validation_model_eval',
        'on_validation_start',
@@ -567,6 +587,8 @@ def test_trainer_model_hook_system_test(tmpdir):
    assert model.called == []
    trainer.test(model, verbose=False)
    expected = [
+        'prepare_data',
+        'configure_callbacks',
        'setup_test',
        'on_test_model_eval',
        'on_test_start',
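The test diff above extends the hook-tracking model in ``tests/models/test_hooks.py``: each overridden hook appends its own name to ``self.called`` before delegating to the parent implementation, and the tests compare that list against an expected call order. The following is a self-contained sketch of the same pattern; ``TrackingModel``, the toy dataloader, and the ``Trainer`` arguments are illustrative choices, not code from the repository:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader
    from pytorch_lightning import LightningModule, Trainer


    class TrackingModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            self.called = []

        def prepare_data(self):
            self.called.append("prepare_data")

        def configure_optimizers(self):
            self.called.append("configure_optimizers")
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def on_fit_start(self):
            self.called.append("on_fit_start")

        def training_step(self, batch, batch_idx):
            self.called.append("training_step")
            return self.layer(batch).sum()

        def train_dataloader(self):
            # tiny random dataset, enough to exercise the hooks once
            return DataLoader(torch.randn(16, 32), batch_size=8)


    model = TrackingModel()
    trainer = Trainer(max_epochs=1, limit_train_batches=2, logger=False, checkpoint_callback=False)
    trainer.fit(model)
    print(model.called)  # hook names in the order the Trainer invoked them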
