diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index ceb9d98505acc..60e6ea88b4250 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -85,7 +85,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
             model: the LightningModule
         """
         self.setup_training_type_plugin(self.training_type_plugin, model)
-        self.setup_optimizers(trainer)
+        if not self.training_type_plugin.setup_optimizers_in_pre_dispatch:
+            self.setup_optimizers(trainer)
         self.setup_precision_plugin(self.precision_plugin)
 
     def start_training(self, trainer: 'Trainer') -> None:
@@ -97,12 +98,14 @@ def start_evaluating(self, trainer: 'Trainer') -> None:
     def start_predicting(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_predicting(trainer)
 
-    def pre_dispatch(self) -> None:
+    def pre_dispatch(self, trainer: 'Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.pre_dispatch()
+        if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
+            self.setup_optimizers(trainer)
         self.precision_plugin.pre_dispatch()
 
-    def post_dispatch(self) -> None:
+    def post_dispatch(self, trainer: 'Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 6a87792c7bd03..b6f1be359bbf2 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -182,3 +182,13 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule):
 
     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
         optimizer.step(closure=lambda_closure, **kwargs)
+
+    @property
+    def setup_optimizers_in_pre_dispatch(self) -> bool:
+        """
+        Override to delay setting up optimizers and schedulers until after dispatch.
+        This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
+        However, this may break certain precision plugins such as APEX, which require optimizers to be set up.
+        Returns: If True, delay setting up the optimizers until pre_dispatch; otherwise set them up within setup.
+        """
+        return False
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 53b4920bd85ef..0e9e28c9996f2 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -495,7 +495,7 @@ def fit(
         return self.accelerator.results or 1
 
     def pre_dispatch(self):
-        self.accelerator.pre_dispatch()
+        self.accelerator.pre_dispatch(self)
 
         # log hyper-parameters
         if self.logger is not None:
@@ -505,7 +505,7 @@ def pre_dispatch(self):
             self.logger.save()
 
     def post_dispatch(self):
-        self.accelerator.post_dispatch()
+        self.accelerator.post_dispatch(self)
         self.accelerator.teardown()
 
     def dispatch(self):
diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py
index 81a5132e47356..349e4175a7444 100644
--- a/tests/accelerators/test_cpu.py
+++ b/tests/accelerators/test_cpu.py
@@ -2,11 +2,12 @@
 
 import pytest
 import torch
-
+from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.plugins import SingleDevicePlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers.boring_model import BoringModel
 
 
 def test_unsupported_precision_plugins():
@@ -18,3 +19,35 @@ def test_unsupported_precision_plugins():
     )
     with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
         accelerator.setup(trainer=trainer, model=model)
+
+
+@pytest.mark.parametrize("delay_dispatch", [True, False])
+def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch):
+    """
+    Test that, when using a custom training type plugin that delays optimizer setup,
+    we do not call ``setup_optimizers`` until ``pre_dispatch``.
+    """
+
+    class TestModel(BoringModel):
+        def on_fit_start(self):
+            if delay_dispatch:
+                # Ensure we haven't set up optimizers if setup is delayed to pre_dispatch
+                assert len(self.trainer.optimizers) == 0
+            else:
+                assert len(self.trainer.optimizers) > 0
+
+        def on_fit_end(self):
+            assert len(self.trainer.optimizers) > 0
+
+    class CustomPlugin(SingleDevicePlugin):
+        @property
+        def setup_optimizers_in_pre_dispatch(self) -> bool:
+            return delay_dispatch
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        plugins=CustomPlugin(device=torch.device("cpu"))
+    )
+    trainer.fit(model)