Re-design call_hook interface (#10575)
daniellepintz authored Dec 4, 2021
1 parent a28b4cd commit 6043179
Showing 18 changed files with 224 additions and 360 deletions.
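
Before this commit, a single Trainer.call_hook(name, *args, **kwargs) fanned one hook name out to the callbacks, the LightningModule, and the training type plugin (TTP), and returned one merged result. The redesign replaces it with targeted helpers (_call_callback_hooks, _call_lightning_module_hook, _call_ttp_hook, _call_accelerator_hook), so each call site states which component it is invoking and what it does with the return value. The sketch below illustrates the dispatch pattern with toy stand-ins; the helper names come from this diff, but the bodies are assumptions, not the real Trainer internals:

    from typing import Any, Optional

    class SketchTrainer:
        """Toy illustration of the split hook dispatch; not the real pytorch_lightning.Trainer."""

        def __init__(self, callbacks: list, lightning_module: Any, training_type_plugin: Any) -> None:
            self.callbacks = callbacks
            self.lightning_module = lightning_module
            self.training_type_plugin = training_type_plugin

        def _call_callback_hooks(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
            # Callback hooks are fire-and-forget: every callback runs, return values are discarded.
            for callback in self.callbacks:
                fn = getattr(callback, hook_name, None)
                if callable(fn):
                    fn(self, self.lightning_module, *args, **kwargs)

        def _call_lightning_module_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Optional[Any]:
            # The module hook may return an output (e.g. training_step_end). Setting
            # _current_fx_name here mirrors the bookkeeping the old call sites did by hand.
            fn = getattr(self.lightning_module, hook_name, None)
            if not callable(fn):
                return None
            self.lightning_module._current_fx_name = hook_name
            return fn(*args, **kwargs)

        def _call_ttp_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Optional[Any]:
            # The training type plugin gets its own pass and may also return an output.
            fn = getattr(self.training_type_plugin, hook_name, None)
            return fn(*args, **kwargs) if callable(fn) else None
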
4 changes: 3 additions & 1 deletion pl_examples/loop_examples/yielding_training_step.py
@@ -88,7 +88,9 @@ def _training_step(self, generator):
         training_step_output = next(generator)
         self.trainer.training_type_plugin.post_training_step()

-        training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
+        model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
+        ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
+        training_step_output = ttp_output if model_output is None else model_output

         # The closure result takes care of properly detaching the loss for logging and performs
         # some additional checks that the output format is correct.
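
The replacement pattern here is worth noting: training_step_end can be implemented both on the LightningModule and on the training type plugin, so the new code invokes each one separately and keeps the module's output, falling back to the plugin's only when the module hook returns None. The same module-output-wins rule recurs below in _evaluation_step_end, on_train_batch_start, and the manual optimization loop.
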
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -374,7 +374,7 @@ def log(
         value = apply_to_collection(value, numbers.Number, self.__to_tensor)

         if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name):
-            # if we started a new epoch (running it's first batch) the hook name has changed
+            # if we started a new epoch (running its first batch) the hook name has changed
             # reset any tensors for the new hook name
             results.reset(metrics=False, fx=self._current_fx_name)

47 changes: 27 additions & 20 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -19,7 +19,6 @@
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
-from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -170,16 +169,20 @@ def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
         self._results.to(device=self.trainer.lightning_module.device)

         if self.trainer.testing:
-            self.trainer.call_hook("on_test_start", *args, **kwargs)
+            self.trainer._call_callback_hooks("on_test_start", *args, **kwargs)
+            self.trainer._call_lightning_module_hook("on_test_start", *args, **kwargs)
+            self.trainer._call_ttp_hook("on_test_start", *args, **kwargs)
         else:
-            self.trainer.call_hook("on_validation_start", *args, **kwargs)
+            self.trainer._call_callback_hooks("on_validation_start", *args, **kwargs)
+            self.trainer._call_lightning_module_hook("on_validation_start", *args, **kwargs)
+            self.trainer._call_ttp_hook("on_validation_start", *args, **kwargs)

     def _on_evaluation_model_eval(self) -> None:
         """Sets model to eval mode."""
         if self.trainer.testing:
-            self.trainer.call_hook("on_test_model_eval")
+            self.trainer._call_lightning_module_hook("on_test_model_eval")
         else:
-            self.trainer.call_hook("on_validation_model_eval")
+            self.trainer._call_lightning_module_hook("on_validation_model_eval")

     def _on_evaluation_model_train(self) -> None:
         """Sets model to train mode."""
@@ -192,22 +195,29 @@ def _on_evaluation_model_train(self) -> None:
     def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_end`` hook."""
         if self.trainer.testing:
-            self.trainer.call_hook("on_test_end", *args, **kwargs)
+            self.trainer._call_callback_hooks("on_test_end", *args, **kwargs)
+            self.trainer._call_lightning_module_hook("on_test_end", *args, **kwargs)
+            self.trainer._call_ttp_hook("on_test_end", *args, **kwargs)
         else:
-            self.trainer.call_hook("on_validation_end", *args, **kwargs)
+            self.trainer._call_callback_hooks("on_validation_end", *args, **kwargs)
+            self.trainer._call_lightning_module_hook("on_validation_end", *args, **kwargs)
+            self.trainer._call_ttp_hook("on_validation_end", *args, **kwargs)

         # reset the logger connector state
         self.trainer.logger_connector.reset_results()

     def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
         self.trainer.logger_connector.on_epoch_start()
-        self.trainer.call_hook("on_epoch_start", *args, **kwargs)
+        self.trainer._call_callback_hooks("on_epoch_start", *args, **kwargs)
+        self.trainer._call_lightning_module_hook("on_epoch_start", *args, **kwargs)

         if self.trainer.testing:
-            self.trainer.call_hook("on_test_epoch_start", *args, **kwargs)
+            self.trainer._call_callback_hooks("on_test_epoch_start", *args, **kwargs)
+            self.trainer._call_lightning_module_hook("on_test_epoch_start", *args, **kwargs)
         else:
-            self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)
+            self.trainer._call_callback_hooks("on_validation_epoch_start", *args, **kwargs)
+            self.trainer._call_lightning_module_hook("on_validation_epoch_start", *args, **kwargs)

     def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
         """Runs ``{validation/test}_epoch_end``"""
@@ -220,20 +230,17 @@ def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
         )

         # call the model epoch end
-        model = self.trainer.lightning_module
         if self.trainer.testing:
-            if is_overridden("test_epoch_end", model):
-                model._current_fx_name = "test_epoch_end"
-                model.test_epoch_end(output_or_outputs)
-
+            self.trainer._call_lightning_module_hook("test_epoch_end", output_or_outputs)
         else:
-            if is_overridden("validation_epoch_end", model):
-                model._current_fx_name = "validation_epoch_end"
-                model.validation_epoch_end(output_or_outputs)
+            self.trainer._call_lightning_module_hook("validation_epoch_end", output_or_outputs)

     def _on_evaluation_epoch_end(self) -> None:
         """Runs ``on_{validation/test}_epoch_end`` hook."""
         hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
-        self.trainer.call_hook(hook_name)
-        self.trainer.call_hook("on_epoch_end")
+        self.trainer._call_callback_hooks(hook_name)
+        self.trainer._call_lightning_module_hook(hook_name)
+
+        self.trainer._call_callback_hooks("on_epoch_end")
+        self.trainer._call_lightning_module_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
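
Two things stand out in this file. First, on_test_model_eval / on_validation_model_eval route only through _call_lightning_module_hook: those hooks exist on the LightningModule alone, so there is no callback or plugin pass. Second, _evaluation_epoch_end drops both the is_overridden import and the guard around test_epoch_end / validation_epoch_end, which suggests the new helper performs that check internally. A sketch of that assumed internal guard, reusing the same is_overridden utility the old call site used (an assumption, not the actual method body):

    from pytorch_lightning.utilities.model_helpers import is_overridden

    def call_module_hook_if_overridden(trainer, hook_name: str, *args):
        # Assumed behaviour of _call_lightning_module_hook: do nothing unless the
        # user's module actually overrides the hook (the check the call site used to do).
        model = trainer.lightning_module
        if not is_overridden(hook_name, model):
            return None
        model._current_fx_name = hook_name
        return getattr(model, hook_name)(*args)
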
15 changes: 11 additions & 4 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -107,8 +107,12 @@ def _on_predict_start(self) -> None:
         self.trainer.lightning_module.zero_grad()

         # hook
-        self.trainer.call_hook("on_predict_start")
-        self.trainer.call_hook("on_predict_epoch_start")
+        self.trainer._call_callback_hooks("on_predict_start")
+        self.trainer._call_lightning_module_hook("on_predict_start")
+        self.trainer._call_ttp_hook("on_predict_start")
+
+        self.trainer._call_callback_hooks("on_predict_epoch_start")
+        self.trainer._call_lightning_module_hook("on_predict_epoch_start")

     def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
         """Calls ``on_predict_epoch_end`` hook.
@@ -118,7 +122,8 @@ def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
         """
         results = self.predictions

-        self.trainer.call_hook("on_predict_epoch_end", results)
+        self.trainer._call_callback_hooks("on_predict_epoch_end", results)
+        self.trainer._call_lightning_module_hook("on_predict_epoch_end", results)

         if self.return_predictions:
             return results[0] if self.num_dataloaders == 1 else results
@@ -130,7 +135,9 @@ def _on_predict_end(self) -> None:
         self.epoch_batch_indices = []

         # hook
-        self.trainer.call_hook("on_predict_end")
+        self.trainer._call_callback_hooks("on_predict_end")
+        self.trainer._call_lightning_module_hook("on_predict_end")
+        self.trainer._call_ttp_hook("on_predict_end")

     def _on_predict_model_eval(self) -> None:
         """Calls ``on_predict_model_eval`` hook."""
21 changes: 11 additions & 10 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -218,20 +218,18 @@ def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]:
             the outputs of the step
         """
         if self.trainer.testing:
-            self.trainer.lightning_module._current_fx_name = "test_step"
-            with self.trainer.profiler.profile("test_step"):
-                output = self.trainer.accelerator.test_step(*kwargs.values())
+            output = self.trainer._call_accelerator_hook("test_step", *kwargs.values())
         else:
-            self.trainer.lightning_module._current_fx_name = "validation_step"
-            with self.trainer.profiler.profile("validation_step"):
-                output = self.trainer.accelerator.validation_step(*kwargs.values())
+            output = self.trainer._call_accelerator_hook("validation_step", *kwargs.values())

         return output

     def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         """Calls the `{validation/test}_step_end` hook."""
         hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
-        output = self.trainer.call_hook(hook_name, *args, **kwargs)
+        model_output = self.trainer._call_lightning_module_hook(hook_name, *args, **kwargs)
+        ttp_output = self.trainer._call_ttp_hook(hook_name, *args, **kwargs)
+        output = ttp_output if model_output is None else model_output
         return output

     def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
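
The removed lines make visible what _call_accelerator_hook appears to absorb: setting lightning_module._current_fx_name and wrapping the accelerator call in the profiler. A plausible reconstruction inferred from that call-site boilerplate (an assumption, not the actual method body):

    from typing import Any, Optional

    def _call_accelerator_hook(trainer, hook_name: str, *args: Any) -> Optional[Any]:
        trainer.lightning_module._current_fx_name = hook_name  # lets self.log() attribute metrics
        fn = getattr(trainer.accelerator, hook_name)  # e.g. test_step / validation_step / predict_step
        with trainer.profiler.profile(hook_name):  # keeps the per-hook profiler region
            return fn(*args)
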
@@ -249,9 +247,11 @@ def _on_evaluation_batch_start(self, **kwargs: Any) -> None:

         kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
         if self.trainer.testing:
-            self.trainer.call_hook("on_test_batch_start", *kwargs.values())
+            self.trainer._call_callback_hooks("on_test_batch_start", *kwargs.values())
+            self.trainer._call_lightning_module_hook("on_test_batch_start", *kwargs.values())
         else:
-            self.trainer.call_hook("on_validation_batch_start", *kwargs.values())
+            self.trainer._call_callback_hooks("on_validation_batch_start", *kwargs.values())
+            self.trainer._call_lightning_module_hook("on_validation_batch_start", *kwargs.values())

     def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any) -> None:
         """The ``on_{validation/test}_batch_end`` hook.
@@ -264,7 +264,8 @@ def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any) -> None:
         """
         kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
         hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
-        self.trainer.call_hook(hook_name, output, *kwargs.values())
+        self.trainer._call_callback_hooks(hook_name, output, *kwargs.values())
+        self.trainer._call_lightning_module_hook(hook_name, output, *kwargs.values())

         self.trainer.logger_connector.on_batch_end()
11 changes: 5 additions & 6 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -125,21 +125,20 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         # extract batch_indices and store them
         self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else []

-        model_ref = self.trainer.lightning_module
-
-        self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)
+        self.trainer._call_callback_hooks("on_predict_batch_start", batch, batch_idx, dataloader_idx)
+        self.trainer._call_lightning_module_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)

         self.batch_progress.increment_started()

-        model_ref._current_fx_name = "predict_step"
-        predictions = self.trainer.accelerator.predict_step(*step_kwargs.values())
+        predictions = self.trainer._call_accelerator_hook("predict_step", *step_kwargs.values())

         self.batch_progress.increment_processed()

         if predictions is None:
             self._warning_cache.warn("predict returned None; if it was on purpose, ignore this warning...")

-        self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
+        self.trainer._call_callback_hooks("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
+        self.trainer._call_lightning_module_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)

         self.batch_progress.increment_completed()

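The deleted model_ref lines are the same _current_fx_name bookkeeping that _call_accelerator_hook absorbs (see the sketch after the evaluation step hunk above), so predict_step now follows the exact pattern of test_step and validation_step.
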
33 changes: 23 additions & 10 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -132,8 +132,11 @@ def reset(self) -> None:
     def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         # hook
         self.trainer.logger_connector.on_epoch_start()
-        self.trainer.call_hook("on_epoch_start")
-        self.trainer.call_hook("on_train_epoch_start")
+        self.trainer._call_callback_hooks("on_epoch_start")
+        self.trainer._call_lightning_module_hook("on_epoch_start")
+
+        self.trainer._call_callback_hooks("on_train_epoch_start")
+        self.trainer._call_lightning_module_hook("on_train_epoch_start")
         self.trainer.fit_loop.epoch_progress.increment_started()

         self._reload_dataloader_state_dict(data_fetcher)
@@ -165,7 +168,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
             batch_output = []
         else:
             # hook
-            self.trainer.call_hook("on_batch_start")
+            self.trainer._call_callback_hooks("on_batch_start")

             # TODO: Update this in v1.7 (deprecation: #9816)
             model_fx = self.trainer.lightning_module.on_train_batch_start
@@ -176,7 +179,12 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
             )

             # hook
-            response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
+            self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
+            model_response = self.trainer._call_lightning_module_hook(
+                "on_train_batch_start", batch, batch_idx, **extra_kwargs
+            )
+            ttp_response = self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
+            response = ttp_response if model_response is None else model_response
             if response == -1:
                 self.batch_progress.increment_processed()
                 raise StopIteration
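
This block preserves an existing contract: a return value of -1 from on_train_batch_start skips the rest of the current training epoch. Under the new dispatch the callback pass discards return values, so only the LightningModule, or the TTP when the module returns None, can trigger the early exit. A minimal user-side fragment that relies on this (only the relevant hook is shown; the rest of the module is omitted, and the two-argument signature assumes the post-#9816 form):

    from typing import Any, Optional

    import pytorch_lightning as pl

    class SkipRestOfEpoch(pl.LightningModule):
        def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
            # Returning -1 ends the current training epoch early; the value comes
            # back through _call_lightning_module_hook and wins over the TTP's.
            if batch_idx >= 10:
                return -1
            return None
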
@@ -207,8 +215,11 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
                 if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
                 else {}
             )
-            self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
-            self.trainer.call_hook("on_batch_end")
+            self.trainer._call_callback_hooks("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
+            self.trainer._call_lightning_module_hook(
+                "on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs
+            )
+            self.trainer._call_callback_hooks("on_batch_end")
             self.trainer.logger_connector.on_batch_end()

             self.batch_progress.increment_completed()
@@ -276,8 +287,7 @@ def on_run_end(self) -> None:
                 )
                 # run lightning module hook training_epoch_end
                 # refresh the result for custom logging at the epoch level
-                model._current_fx_name = "training_epoch_end"
-                epoch_end_outputs = model.training_epoch_end(epoch_end_outputs)
+                epoch_end_outputs = self.trainer._call_lightning_module_hook("training_epoch_end", epoch_end_outputs)
                 if epoch_end_outputs is not None:
                     raise MisconfigurationException(
                         "`training_epoch_end` expects a return of None. "
@@ -289,8 +299,11 @@ def on_run_end(self) -> None:
         self.trainer.fit_loop.epoch_progress.increment_processed()

         # call train epoch end hooks
-        self.trainer.call_hook("on_train_epoch_end")
-        self.trainer.call_hook("on_epoch_end")
+        self.trainer._call_callback_hooks("on_train_epoch_end")
+        self.trainer._call_lightning_module_hook("on_train_epoch_end")
+
+        self.trainer._call_callback_hooks("on_epoch_end")
+        self.trainer._call_lightning_module_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()

         if self._num_ready_batches_reached():
8 changes: 6 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -193,7 +193,9 @@ def on_run_start(self) -> None:  # type: ignore[override]
         self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
         self._is_fresh_start_epoch = True
         self._results.to(device=self.trainer.lightning_module.device)
-        self.trainer.call_hook("on_train_start")
+        self.trainer._call_callback_hooks("on_train_start")
+        self.trainer._call_lightning_module_hook("on_train_start")
+        self.trainer._call_accelerator_hook("on_train_start")

     def on_advance_start(self) -> None:  # type: ignore[override]
         """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
@@ -248,7 +250,9 @@ def on_run_end(self) -> None:
         self.current_epoch = max(self.current_epoch - 1, 0)

         # hook
-        self.trainer.call_hook("on_train_end")
+        self.trainer._call_callback_hooks("on_train_end")
+        self.trainer._call_lightning_module_hook("on_train_end")
+        self.trainer._call_ttp_hook("on_train_end")

         # give accelerators a chance to finish
         self.trainer.training_type_plugin.on_train_end()
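
Note the asymmetry at the two ends of the fit loop: on_train_start is additionally routed through _call_accelerator_hook, while on_train_end goes through _call_ttp_hook, with the training type plugin's own on_train_end still invoked directly two lines later. The diff does not spell out the reason; presumably start-of-training setup belongs to the accelerator while teardown stays with the plugin.
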
11 changes: 5 additions & 6 deletions pytorch_lightning/loops/optimization/manual_loop.py
@@ -102,15 +102,14 @@ def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
         )

         # manually capture logged metrics
-        lightning_module._current_fx_name = "training_step"
-        with self.trainer.profiler.profile("training_step"):
-            training_step_output = self.trainer.accelerator.training_step(*step_kwargs.values())
-            self.trainer.training_type_plugin.post_training_step()
+        training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
+        self.trainer.training_type_plugin.post_training_step()

         del step_kwargs

-        training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
-
+        model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
+        ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
+        training_step_output = ttp_output if model_output is None else model_output
         self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

         result = self.output_result_cls.from_training_step_output(training_step_output)
[Diffs for the remaining 9 changed files were not loaded in this view.]