diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 274a19333f51a..5624d376bf816 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1538,7 +1538,7 @@ def _call_accelerator_hook( **kwargs: Any, ) -> Optional[Any]: self.lightning_module._current_fx_name = hook_name - fn = getattr(self.accelerator, hook_name) + fn = getattr(self.training_type_plugin, hook_name) if not callable(fn): return None