Remove _call_accelerator_hook Trainer method #10999

Merged
4 changes: 0 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -113,10 +113,6 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """
         raise NotImplementedError
 
-    def on_train_start(self) -> None:
-        """Called when train begins."""
-        return self.training_type_plugin.on_train_start()
-
     @staticmethod
     @abstractmethod
     def auto_device_count() -> int:
5 changes: 1 addition & 4 deletions pytorch_lightning/accelerators/gpu.py
@@ -45,10 +45,7 @@ def setup_environment(self) -> None:
 
     def setup(self, trainer: "pl.Trainer") -> None:
         self.set_nvidia_flags(trainer.local_rank)
-        return super().setup(trainer)
-
-    def on_train_start(self) -> None:
-        super().on_train_start()
+        super().setup(trainer)
         # clear cache before training
         torch.cuda.empty_cache()
 
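Read together with the surrounding context lines, the resulting `GPUAccelerator.setup` looks roughly like the sketch below, reconstructed from this hunk rather than copied from the repository; the cache clear that previously ran in `on_train_start` now happens during `setup`.

```python
import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator


class GPUAccelerator(Accelerator):  # simplified: the real class defines more methods
    def setup(self, trainer: "pl.Trainer") -> None:
        # still configures the NVIDIA flags for this rank first (defined on the real class)
        self.set_nvidia_flags(trainer.local_rank)
        super().setup(trainer)
        # clear cache before training; this previously happened in on_train_start
        torch.cuda.empty_cache()
```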
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -195,7 +195,7 @@ def on_run_start(self) -> None:  # type: ignore[override]
         self._results.to(device=self.trainer.lightning_module.device)
         self.trainer._call_callback_hooks("on_train_start")
         self.trainer._call_lightning_module_hook("on_train_start")
-        self.trainer._call_accelerator_hook("on_train_start")
+        self.trainer._call_ttp_hook("on_train_start")
 
     def on_advance_start(self) -> None:  # type: ignore[override]
         """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
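The behavioural effect of this one-line change is that a training type plugin's `on_train_start` is now reached through `Trainer._call_ttp_hook` rather than through the accelerator. As a hedged illustration (the plugin class name and the exact Trainer flag used to register it are assumptions, not part of this diff), a custom plugin overriding the hook still fires at the same point in `FitLoop.on_run_start`:

```python
import torch

from pytorch_lightning.plugins import SingleDevicePlugin  # assumed import path for this PL version


class LoggingSingleDevicePlugin(SingleDevicePlugin):
    """Hypothetical plugin: its on_train_start is still called once, right before training."""

    def on_train_start(self) -> None:
        print("training type plugin: on_train_start")
        return super().on_train_start()


# Hypothetical usage: pass an instance to the Trainer, e.g. via its plugins/strategy
# argument (the exact flag depends on the Lightning version in use).
plugin = LoggingSingleDevicePlugin(device=torch.device("cpu"))
```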
23 changes: 0 additions & 23 deletions pytorch_lightning/trainer/trainer.py
@@ -1592,29 +1592,6 @@ def _call_ttp_hook(
 
         return output
 
-    # TODO: eventually no longer need this
-    def _call_accelerator_hook(
-        self,
-        hook_name: str,
-        *args: Any,
-        **kwargs: Any,
-    ) -> Any:
-        pl_module = self.lightning_module
-        prev_fx_name = pl_module._current_fx_name
-        pl_module._current_fx_name = hook_name
-
-        fn = getattr(self.accelerator, hook_name)
-        if not callable(fn):
-            return
-
-        with self.profiler.profile(hook_name):
-            output = fn(*args, **kwargs)
-
-        # restore current_fx when nested context
-        pl_module._current_fx_name = prev_fx_name
-
-        return output
-
     @staticmethod
     def _parse_devices(
         gpus: Optional[Union[List[int], str, int]],
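Only the tail of the surviving `_call_ttp_hook` (`return output`) is visible as context in this hunk. By analogy with the removed `_call_accelerator_hook`, it presumably follows the same dispatch pattern but targets the training type plugin; the following is a sketch under that assumption, not the verbatim Lightning source:

```python
from typing import Any


class Trainer:  # excerpt-style sketch; the real Trainer defines far more
    def _call_ttp_hook(
        self,
        hook_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        # mirror of the removed _call_accelerator_hook, but dispatching to the
        # training type plugin instead of the accelerator (assumed, not shown in the diff)
        pl_module = self.lightning_module
        prev_fx_name = pl_module._current_fx_name
        pl_module._current_fx_name = hook_name

        fn = getattr(self.training_type_plugin, hook_name)
        if not callable(fn):
            return

        with self.profiler.profile(hook_name):
            output = fn(*args, **kwargs)

        # restore current_fx when nested context
        pl_module._current_fx_name = prev_fx_name

        return output
```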