diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5a71defcfd925..fd6ef13bf7b27 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1239,9 +1239,10 @@ def call_hook( # next call hook in lightningModule output = None - model_fx = getattr(pl_module, hook_name, None) - if callable(model_fx): - output = model_fx(*args, **kwargs) + if is_overridden(hook_name, pl_module): + model_fx = getattr(pl_module, hook_name) + if callable(model_fx): + output = model_fx(*args, **kwargs) # call the accelerator hook if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):