Commit

2/n Move Accelerator into strategy - move dispatch functions to strategy
four4fish committed Dec 1, 2021
1 parent 619ef7a commit 558e4f5
Showing 13 changed files with 36 additions and 37 deletions.
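In summary, the three dispatch hooks (pre_dispatch, dispatch, post_dispatch) are removed from Accelerator and implemented on TrainingTypePlugin, and the Trainer now calls the strategy directly. The sketch below condenses the new call flow visible in the diff; it uses simplified stand-in classes rather than the real Lightning implementations, purely to illustrate the pattern.

# Self-contained sketch of the pattern this commit applies: the
# pre_dispatch/dispatch/post_dispatch hooks are deleted from Accelerator and
# implemented on the strategy (TrainingTypePlugin), which now receives the
# trainer directly. These classes are simplified stand-ins for illustration.

class PrecisionPlugin:
    def pre_dispatch(self) -> None:
        ...

    def dispatch(self, trainer: "Trainer") -> None:
        ...

    def post_dispatch(self) -> None:
        ...


class TrainingTypePlugin:
    setup_optimizers_in_pre_dispatch = False

    def __init__(self) -> None:
        self.precision_plugin = PrecisionPlugin()

    def _move_optimizer_state(self) -> None:
        ...

    def setup_optimizers(self, trainer: "Trainer") -> None:
        ...

    def pre_dispatch(self, trainer: "Trainer") -> None:
        # Work that used to live in Accelerator.pre_dispatch now runs here.
        self._move_optimizer_state()
        if self.setup_optimizers_in_pre_dispatch:
            self.setup_optimizers(trainer)
        self.precision_plugin.pre_dispatch()

    def dispatch(self, trainer: "Trainer") -> None:
        self.precision_plugin.dispatch(trainer)

    def post_dispatch(self, trainer: "Trainer") -> None:
        self.precision_plugin.post_dispatch()


class Trainer:
    def __init__(self) -> None:
        self.training_type_plugin = TrainingTypePlugin()

    def _pre_dispatch(self) -> None:
        # Before this commit the call went through self.accelerator.pre_dispatch(self).
        self.training_type_plugin.pre_dispatch(self)
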
20 changes: 0 additions & 20 deletions pytorch_lightning/accelerators/accelerator.py
@@ -71,26 +71,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
"""
self.training_type_plugin.setup(trainer)

def pre_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin._move_optimizer_state()

self.training_type_plugin.pre_dispatch()
if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
self.training_type_plugin.setup_optimizers(trainer)

self.training_type_plugin.precision_plugin.pre_dispatch()

def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.dispatch(trainer)
self.training_type_plugin.precision_plugin.dispatch(trainer)

def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch(trainer)
self.training_type_plugin.precision_plugin.post_dispatch()

@property
def model(self) -> Module:
"""Returns the model.
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -205,7 +205,7 @@ def pre_dispatch(self) -> None:
"""Hook to do something before the training/evaluation/prediction starts."""

def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something when ``Accelerator.dispatch()`` gets called."""
"""Hook to do something when ``TrainingTypePlugin.dispatch()`` gets called."""

def post_dispatch(self) -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -338,7 +338,7 @@ def determine_ddp_device_ids(self):
return None
return [self.root_device.index]

def pre_dispatch(self):
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
# share ddp pids to all processes
self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
if self._should_run_deadlock_detection():
@@ -355,8 +355,11 @@ def pre_dispatch(self):
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

super().pre_dispatch(trainer)

def post_dispatch(self, trainer: "pl.Trainer") -> None:
self.cluster_environment.teardown()
super().post_dispatch(trainer)

def barrier(self, *args, **kwargs) -> None:
if not distributed_available():
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -215,7 +215,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()

def post_dispatch(self, trainer: "pl.Trainer"):
def post_dispatch(self, trainer: "pl.Trainer") -> None:
# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
@@ -230,6 +230,7 @@ def post_dispatch(self, trainer: "pl.Trainer"):

# recover the weights of the processes trained in the children
self.__recover_child_process_weights(best_path, last_path)
super().post_dispatch(trainer)

def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -374,9 +374,10 @@ def _set_node_environment_variables(self) -> None:
def restore_checkpoint_after_pre_dispatch(self) -> bool:
return True

def pre_dispatch(self):
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
self.init_deepspeed()
self.barrier()
super().pre_dispatch(trainer)

def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Setup a model and multiple optimizers together.
4 changes: 3 additions & 1 deletion pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -16,6 +16,7 @@

import torch

import pytorch_lightning as pl
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -159,11 +160,12 @@ def configure_ddp(self) -> None:
# setup optimizers after fully sharded has wrapped the lightning module
self.setup_optimizers(self.lightning_module.trainer)

def pre_dispatch(self) -> None:
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)
self.configure_ddp()
self.barrier()
super().pre_dispatch(trainer)

def model_to_device(self) -> None:
# ensure we update the device type in the lightning module
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/horovod.py
@@ -78,7 +78,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.model_to_device()
super().setup(trainer)

def pre_dispatch(self):
def pre_dispatch(self, trainer: "pl.Trainer") -> None:

if not self.lightning_module.trainer.training:
# no need to setup optimizers
@@ -109,6 +109,7 @@ def _unpack_lightning_optimizer(opt):
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

self.optimizers = self._wrap_optimizers(optimizers)
super().pre_dispatch(trainer)

def start_training(self, trainer):
with ExitStack() as stack:
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/ipu.py
@@ -131,7 +131,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
if len(self.optimizers) > 1:
raise MisconfigurationException("IPUs currently only support one optimizer.")

def pre_dispatch(self) -> None:
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision)
self.model = model

@@ -161,6 +161,7 @@ def pre_dispatch(self) -> None:
elif trainer_fn == TrainerFn.PREDICTING:
model = poptorch.inferenceModel(model=model, options=self.inference_opts)
self.poptorch_models[RunningStage.PREDICTING] = model
super().pre_dispatch(trainer)

@property
def replication_factor(self) -> int:
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/single_tpu.py
@@ -77,7 +77,7 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
def model_to_device(self) -> None:
self.model.to(self.root_device)

def pre_dispatch(self) -> None:
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
if isinstance(self.device, int):
self.device = xm.xla_device(self.device)

@@ -86,6 +86,7 @@ def pre_dispatch(self) -> None:

self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()
super().pre_dispatch(trainer)

def save(self, state_dict: Dict, path: _PATH) -> None:
xm.save(state_dict, path)
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -118,9 +118,10 @@ def connect(self, model: "pl.LightningModule") -> None:
self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
return super().connect(model)

def pre_dispatch(self):
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)
super().pre_dispatch(trainer)

def setup(self, trainer: "pl.Trainer") -> None:
self.create_mp_queue()
13 changes: 10 additions & 3 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -481,11 +481,18 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int =
"""Called in the training loop before anything happens for that batch."""
pass

def pre_dispatch(self) -> None:
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self._move_optimizer_state()
if self.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)

self.precision_plugin.pre_dispatch()

def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something at trainer run_stage starts."""
"""Hook to do something before the training/evaluation/prediction starts."""
self.precision_plugin.dispatch(trainer)

def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""
"""Hook to do something after the training/evaluation/prediction starts."""
self.precision_plugin.post_dispatch()
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -1183,7 +1183,7 @@ def _run(
return self.training_type_plugin.results

def _pre_dispatch(self):
self.accelerator.pre_dispatch(self)
self.training_type_plugin.pre_dispatch(self)
self._log_hyperparams()

def _log_hyperparams(self) -> None:
@@ -1224,7 +1224,7 @@ def _log_hyperparams(self) -> None:
self.logger.save()

def _post_dispatch(self):
self.accelerator.post_dispatch(self)
self.training_type_plugin.post_dispatch(self)
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
# which need to happen before.
self.accelerator.teardown()
Expand All @@ -1242,7 +1242,7 @@ def _dispatch(self):
self.training_type_plugin.start_training(self)

def run_stage(self):
self.accelerator.dispatch(self)
self.training_type_plugin.dispatch(self)
self.__setup_profiler()

if self.evaluating:
5 changes: 3 additions & 2 deletions tests/accelerators/test_cpu.py
@@ -5,6 +5,7 @@
import pytest
import torch

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
@@ -55,8 +56,8 @@ def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatc
class TestPlugin(SingleDevicePlugin):
predispatched_called = False

def pre_dispatch(self) -> None:
super().pre_dispatch()
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
super().pre_dispatch(trainer)
self.predispatched_called = True

@property
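
As the test change above illustrates, plugin subclasses that override these hooks now receive the trainer and should forward it to super() so the base-class setup (optimizer state, precision-plugin hooks) still runs. A minimal, hypothetical override under the new signature might look like the following; MyPlugin and the print call are illustrative only, not part of the commit.

import pytorch_lightning as pl
from pytorch_lightning.plugins import SingleDevicePlugin


class MyPlugin(SingleDevicePlugin):
    # Hypothetical custom plugin; mirrors the TestPlugin pattern in the test above.
    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
        super().pre_dispatch(trainer)  # keep the base-class pre-dispatch logic
        print("pre_dispatch finished setting up")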
