2/n Move Accelerator into strategy - remove dispatch functions from Accelerator (#10885)


Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
3 people authored Dec 6, 2021
1 parent 7914e5c commit 2fc64e9
Showing 12 changed files with 20 additions and 22 deletions.
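
Taken together, this diff removes the thin pre_dispatch/dispatch/post_dispatch delegators from Accelerator so that the Trainer calls those hooks directly on the TrainingTypePlugin, which now moves the optimizer state itself and forwards to its precision plugin. The following is a minimal sketch of the resulting call flow, using simplified stand-in classes rather than the real Lightning signatures:

# Simplified stand-ins for illustration only; the real classes live under
# pytorch_lightning/plugins/ and pytorch_lightning/trainer/.
class PrecisionPlugin:
    def dispatch(self, trainer) -> None:
        """Hook called when TrainingTypePlugin.dispatch() runs."""

    def post_dispatch(self) -> None:
        """Hook called after the run finishes."""


class TrainingTypePlugin:
    def __init__(self, precision_plugin: PrecisionPlugin) -> None:
        self.precision_plugin = precision_plugin

    def _move_optimizer_state(self) -> None:
        """Move optimizer state to the target device (a no-op in this sketch)."""

    def pre_dispatch(self, trainer) -> None:
        # The plugin now moves the optimizer state itself; previously
        # Accelerator.pre_dispatch did this before delegating here.
        self._move_optimizer_state()

    def dispatch(self, trainer) -> None:
        self.precision_plugin.dispatch(trainer)

    def post_dispatch(self, trainer) -> None:
        self.precision_plugin.post_dispatch()


class Trainer:
    def __init__(self, training_type_plugin: TrainingTypePlugin) -> None:
        self.training_type_plugin = training_type_plugin

    def _pre_dispatch(self) -> None:
        # Before this commit: self.accelerator.pre_dispatch(self)
        self.training_type_plugin.pre_dispatch(self)

    def run_stage(self) -> None:
        # Before this commit: self.accelerator.dispatch(self)
        self.training_type_plugin.dispatch(self)

    def _post_dispatch(self) -> None:
        # Before this commit: self.accelerator.post_dispatch(self)
        self.training_type_plugin.post_dispatch(self)

The net effect is one less layer of indirection: the Accelerator keeps device-level concerns, while the run-lifecycle hooks live on the strategy.
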
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -218,6 +218,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed method `setup_optimizers_in_pre_dispatch` from the `strategies` and achieve the same logic in `setup` and `pre_dispatch` methods ([#10906](https://github.com/PyTorchLightning/pytorch-lightning/pull/10906))


- Removed methods `pre_dispatch`, `dispatch` and `post_dispatch` from the `Accelerator` ([#10885](https://github.com/PyTorchLightning/pytorch-lightning/pull/10885))


### Fixed

- Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611))
15 changes: 0 additions & 15 deletions pytorch_lightning/accelerators/accelerator.py
@@ -70,21 +70,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
"""
self.training_type_plugin.setup(trainer)

def pre_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin._move_optimizer_state()
self.training_type_plugin.pre_dispatch(trainer)

def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.dispatch(trainer)
self.training_type_plugin.precision_plugin.dispatch(trainer)

def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch(trainer)
self.training_type_plugin.precision_plugin.post_dispatch()

@property
def model(self) -> Module:
"""Returns the model.
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -205,7 +205,7 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -
torch.nn.utils.clip_grad_norm_(parameters, clip_val)

def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something when ``Accelerator.dispatch()`` gets called."""
"""Hook to do something when ``TrainingTypePlugin.dispatch()`` gets called."""

def post_dispatch(self) -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -339,6 +339,7 @@ def determine_ddp_device_ids(self):
return [self.root_device.index]

def pre_dispatch(self, trainer: "pl.Trainer") -> None:
super().pre_dispatch(trainer)
# share ddp pids to all processes
self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
if self._should_run_deadlock_detection():
@@ -357,6 +358,7 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None:

def post_dispatch(self, trainer: "pl.Trainer") -> None:
self.cluster_environment.teardown()
super().post_dispatch(trainer)

def barrier(self, *args, **kwargs) -> None:
if not distributed_available():
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -375,6 +375,7 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool:
return True

def pre_dispatch(self, trainer: "pl.Trainer") -> None:
self._move_optimizer_state()
self.init_deepspeed()
self.barrier()

1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -161,6 +161,7 @@ def configure_ddp(self) -> None:
self.setup_optimizers(self.lightning_module.trainer)

def pre_dispatch(self, trainer: "pl.Trainer") -> None:
self._move_optimizer_state()
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)
self.configure_ddp()
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/horovod.py
@@ -79,7 +79,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
super().setup(trainer)

def pre_dispatch(self, trainer: "pl.Trainer") -> None:

super().pre_dispatch(trainer)
if not self.lightning_module.trainer.training:
# no need to setup optimizers
return
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -130,6 +130,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
raise MisconfigurationException("IPUs currently only support one optimizer.")

def pre_dispatch(self, trainer: "pl.Trainer") -> None:
super().pre_dispatch(trainer)
model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision)
self.model = model

1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -65,6 +65,7 @@ def model_to_device(self) -> None:
self.model.to(self.root_device)

def pre_dispatch(self, trainer: "pl.Trainer") -> None:
super().pre_dispatch(trainer)
if isinstance(self.device, int):
self.device = xm.xla_device(self.device)

1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -119,6 +119,7 @@ def connect(self, model: "pl.LightningModule") -> None:
return super().connect(model)

def pre_dispatch(self, trainer: "pl.Trainer") -> None:
super().pre_dispatch(trainer)
if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)

7 changes: 5 additions & 2 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -459,9 +459,12 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int =

def pre_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self._move_optimizer_state()

def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something at trainer run_stage starts."""
"""Hook to do something before the training/evaluation/prediction starts."""
self.precision_plugin.dispatch(trainer)

def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""
"""Hook to do something after the training/evaluation/prediction starts."""
self.precision_plugin.post_dispatch()
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -1183,7 +1183,7 @@ def _run(
return results

def _pre_dispatch(self):
self.accelerator.pre_dispatch(self)
self.training_type_plugin.pre_dispatch(self)
self._log_hyperparams()

def _log_hyperparams(self) -> None:
@@ -1224,7 +1224,7 @@ def _log_hyperparams(self) -> None:
self.logger.save()

def _post_dispatch(self):
self.accelerator.post_dispatch(self)
self.training_type_plugin.post_dispatch(self)
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
# which need to happen before.
self.accelerator.teardown()
@@ -1242,7 +1242,7 @@ def _dispatch(self) -> Any:
return self.training_type_plugin.start_training(self)

def run_stage(self):
self.accelerator.dispatch(self)
self.training_type_plugin.dispatch(self)
self.__setup_profiler()

if self.evaluating:

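The one-line additions across the plugin diffs above share a single motive: _move_optimizer_state() is no longer called from Accelerator.pre_dispatch, so each strategy has to trigger it itself, either by calling super().pre_dispatch(trainer) (DDP, Horovod, IPU, single TPU, TPU spawn) or by calling self._move_optimizer_state() directly (DeepSpeed, fully sharded). Below is a minimal sketch of a custom plugin following the super() convention, assuming the 1.5-era DDPPlugin import path; MyCustomDDPPlugin and its body are illustrative and not part of this commit:

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin


class MyCustomDDPPlugin(DDPPlugin):
    """Illustrative subclass, not part of the commit."""

    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
        # Calling super() keeps the base-class behaviour, which after this
        # commit includes moving the optimizer state to the target device.
        super().pre_dispatch(trainer)
        # ... custom setup that must run right before the stage starts ...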