Commit
Merge 62fe5a4 into e038e74
Sean Naren authored Mar 4, 2021
2 parents e038e74 + 62fe5a4 commit 0b6a659
Showing 4 changed files with 51 additions and 6 deletions.
9 changes: 6 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -74,7 +74,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
             model: the model to train
         """
         self.connect_training_type_plugin(self.training_type_plugin, model)
-        self.setup_optimizers(trainer)
+        if not self.training_type_plugin.setup_optimizers_after_dispatch:
+            self.setup_optimizers(trainer)
         self.connect_precision_plugin(self.precision_plugin)

     def start_training(self, trainer: 'Trainer') -> None:
@@ -86,12 +87,14 @@ def start_testing(self, trainer: 'Trainer') -> None:
     def start_predicting(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_predicting(trainer)

-    def pre_dispatch(self) -> None:
+    def pre_dispatch(self, trainer: 'Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.pre_dispatch()
+        if self.training_type_plugin.setup_optimizers_after_dispatch:
+            self.setup_optimizers(trainer)
         self.precision_plugin.pre_dispatch()

-    def post_dispatch(self) -> None:
+    def post_dispatch(self, trainer: 'Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()
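The split above (optimizers created in setup when the plugin does not defer, otherwise in pre_dispatch) matters because an optimizer binds to parameter objects at construction time. A minimal toy illustration of that behaviour in plain PyTorch, not taken from this commit:

import torch
from torch import nn

# Toy example: an optimizer captures parameter references when it is built, so a
# plugin that rewrites the model's parameters during dispatch (sharding, flattening,
# wrapping) needs the optimizer to be created only after that rewrite has happened.
model = nn.Linear(4, 4)
early_opt = torch.optim.SGD(model.parameters(), lr=0.1)  # built at "setup" time

model = nn.utils.weight_norm(model)  # stand-in for the plugin's wrapping step
late_opt = torch.optim.SGD(model.parameters(), lr=0.1)  # built "after dispatch"

early_params = {id(p) for group in early_opt.param_groups for p in group["params"]}
late_params = {id(p) for group in late_opt.param_groups for p in group["params"]}
print(early_params == late_params)  # False: the wrapped model trains different tensors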
9 changes: 9 additions & 0 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -169,3 +169,12 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule):

     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
         optimizer.step(closure=lambda_closure, **kwargs)
+
+    @property
+    def setup_optimizers_after_dispatch(self) -> bool:
+        """
+        Override to delay setting optimizers and schedulers till after dispatch.
+        This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
+        Returns: True if delaying setup optimizers till after dispatch, False to call within setup.
+        """
+        return False
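As a usage sketch (hypothetical plugin, not part of this commit): a training type plugin that rewrites the model during dispatch would opt in by overriding the new property. The class name and the pre_dispatch body below are illustrative only.

from pytorch_lightning.plugins import SingleDevicePlugin


class WrappingPlugin(SingleDevicePlugin):
    """Hypothetical plugin that rewrites the model during pre_dispatch."""

    @property
    def setup_optimizers_after_dispatch(self) -> bool:
        # Opt in: the accelerator will call setup_optimizers from pre_dispatch,
        # i.e. only after this plugin has had a chance to wrap the model.
        return True

    def pre_dispatch(self):
        # Stand-in for the real work: wrap/shard/offload the model here, so the
        # parameters the optimizers should train only exist from this point on.
        super().pre_dispatch()


# Hypothetical usage, mirroring the test added below:
# trainer = pl.Trainer(fast_dev_run=True, plugins=WrappingPlugin(device=torch.device("cpu")))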
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -510,10 +510,10 @@ def fit(
         return self.accelerator.results or 1

     def pre_dispatch(self):
-        self.accelerator.pre_dispatch()
+        self.accelerator.pre_dispatch(self)

     def post_dispatch(self):
-        self.accelerator.post_dispatch()
+        self.accelerator.post_dispatch(self)
         self.accelerator.teardown()

     def dispatch(self):
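For orientation, a heavily simplified sketch of how these hooks are sequenced during a fit call, paraphrased from the trainer code around this change rather than taken from the diff:

def fit_flow_sketch(trainer, model):
    # Simplified ordering of the hooks touched by this commit (illustration only).
    trainer.accelerator.setup(trainer, model)  # optimizers built here unless the plugin defers them
    trainer.pre_dispatch()   # accelerator.pre_dispatch(trainer): the plugin wraps the model,
                             # then any deferred optimizers are created
    trainer.dispatch()       # start_training / start_testing / start_predicting
    trainer.post_dispatch()  # accelerator.post_dispatch(trainer), then accelerator.teardown()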
35 changes: 34 additions & 1 deletion tests/accelerators/test_cpu.py
@@ -1,12 +1,13 @@
 from unittest.mock import Mock

 import pytest
+import pytorch_lightning as pl
 import torch

 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.plugins import SingleDevicePlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel


 def test_unsupported_precision_plugins():
@@ -18,3 +19,35 @@ def test_unsupported_precision_plugins():
     )
     with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
         accelerator.setup(trainer=trainer, model=model)
+
+
+@pytest.mark.parametrize("delay_dispatch", [True, False])
+def test_plugin_setup_optimizers_after_dispatch(tmpdir, delay_dispatch):
+    """
+    Test when using a custom training type plugin that delays setup optimizers,
+    we do not call setup optimizers till after ``pre_dispatch``.
+    """
+
+    class TestModel(BoringModel):
+        def on_fit_start(self):
+            if delay_dispatch:
+                # Ensure we haven't setup optimizers if we've delayed dispatch
+                assert len(self.trainer.optimizers) == 0
+            else:
+                assert len(self.trainer.optimizers) > 0
+
+        def on_fit_end(self):
+            assert len(self.trainer.optimizers) > 0
+
+    class CustomPlugin(SingleDevicePlugin):
+        @property
+        def setup_optimizers_after_dispatch(self) -> bool:
+            return delay_dispatch
+
+    model = TestModel()
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        plugins=CustomPlugin(device=torch.device("cpu"))
+    )
+    trainer.fit(model)
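The new test can be run on its own with a standard pytest invocation, e.g. python -m pytest tests/accelerators/test_cpu.py -k test_plugin_setup_optimizers_after_dispatch.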
