diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3da34016f9ec0..985c27195a266 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -239,6 +239,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed `HorovodPlugin.start_{training,evaluating,predicting}` hooks ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989))
 
+- Removed `Accelerator.on_train_start` ([#10999](https://github.com/PyTorchLightning/pytorch-lightning/pull/10999))
 
 ### Fixed
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 69a5bc1091c28..18fd855c94a60 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -113,10 +113,6 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """
         raise NotImplementedError
 
-    def on_train_start(self) -> None:
-        """Called when train begins."""
-        return self.training_type_plugin.on_train_start()
-
     @staticmethod
     @abstractmethod
     def auto_device_count() -> int:
diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index 49d0770e54ff0..c6c82d83c32f5 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -45,10 +45,7 @@ def setup_environment(self) -> None:
 
     def setup(self, trainer: "pl.Trainer") -> None:
         self.set_nvidia_flags(trainer.local_rank)
-        return super().setup(trainer)
-
-    def on_train_start(self) -> None:
-        super().on_train_start()
+        super().setup(trainer)
         # clear cache before training
         torch.cuda.empty_cache()
 
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index f15b0b954361a..0bcd10f916527 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -195,7 +195,7 @@ def on_run_start(self) -> None:  # type: ignore[override]
         self._results.to(device=self.trainer.lightning_module.device)
         self.trainer._call_callback_hooks("on_train_start")
         self.trainer._call_lightning_module_hook("on_train_start")
-        self.trainer._call_accelerator_hook("on_train_start")
+        self.trainer._call_ttp_hook("on_train_start")
 
     def on_advance_start(self) -> None:  # type: ignore[override]
         """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index d6a5910a630db..dde8edce21b77 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1592,29 +1592,6 @@ def _call_ttp_hook(
 
         return output
 
-    # TODO: eventually no longer need this
-    def _call_accelerator_hook(
-        self,
-        hook_name: str,
-        *args: Any,
-        **kwargs: Any,
-    ) -> Any:
-        pl_module = self.lightning_module
-        prev_fx_name = pl_module._current_fx_name
-        pl_module._current_fx_name = hook_name
-
-        fn = getattr(self.accelerator, hook_name)
-        if not callable(fn):
-            return
-
-        with self.profiler.profile(hook_name):
-            output = fn(*args, **kwargs)
-
-        # restore current_fx when nested context
-        pl_module._current_fx_name = prev_fx_name
-
-        return output
-
     @staticmethod
     def _parse_devices(
         gpus: Optional[Union[List[int], str, int]],