Skip to content

Commit

Permalink
remove training_step() from accelerator
Browse files Browse the repository at this point in the history
  • Loading branch information
four4fish committed Dec 4, 2021
2 parents 6fe3211 + 45c28c3 commit 1f4df1e
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 16 deletions.
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,9 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

def training_step(self, *args, **kwargs) -> Optional[Any]:
return self.model(*args, **kwargs)
def training_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    """Run one training step by calling the wrapped DDP module under the precision plugin's context."""
    context = self.precision_plugin.train_step_context()
    with context:
        return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
if isinstance(self.model, DistributedDataParallel):
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,9 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

def training_step(self, *args, **kwargs) -> Optional[Any]:
return self.model(*args, **kwargs)
def training_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    """Run one training step by calling the wrapped DDP module under the precision plugin's context."""
    with self.precision_plugin.train_step_context():
        output = self.model(*args, **kwargs)
    return output

def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
if isinstance(self.model, DistributedDataParallel):
Expand Down
7 changes: 4 additions & 3 deletions pytorch_lightning/plugins/training_type/dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _METRIC_COLLECTION
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, STEP_OUTPUT


class DataParallelPlugin(ParallelPlugin):
Expand Down Expand Up @@ -118,8 +118,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
def reduce_boolean_decision(self, decision: bool) -> bool:
    """Return the local decision unchanged — DP runs in a single process, so no reduction is needed."""
    return decision

def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def training_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    """Run one training step through the wrapped DataParallel module inside the precision context."""
    step_context = self.precision_plugin.train_step_context()
    with step_context:
        result = self.model(*args, **kwargs)
    return result

def validation_step(self, *args, **kwargs):
    """Delegate the validation step to the wrapped model."""
    output = self.model(*args, **kwargs)
    return output
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/fully_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def model_to_device(self) -> None:
self.lightning_module.to(self.root_device)

def training_step(self, *args, **kwargs):
    """Run the wrapped model's ``training_step`` under the precision plugin's train-step context.

    The early ``return self.model.training_step(...)`` placed before the ``with`` block
    made the precision context unreachable dead code; the call now happens inside it.
    """
    with self.precision_plugin.train_step_context():
        return self.model.training_step(*args, **kwargs)

def validation_step(self, *args, **kwargs):
    """Delegate to the wrapped model's ``validation_step``."""
    step = self.model.validation_step
    return step(*args, **kwargs)
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/training_type/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _POPTORCH_AVAILABLE:
import poptorch
Expand Down Expand Up @@ -257,8 +258,9 @@ def _step(self, stage: RunningStage, *args: Any, **kwargs: Any):
self.lightning_module._running_torchscript = False
return out

def training_step(self, *args, **kwargs):
return self._step(RunningStage.TRAINING, *args, **kwargs)
def training_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    """Dispatch the IPU ``_step`` for the training stage inside the precision plugin's context."""
    stage = RunningStage.TRAINING
    with self.precision_plugin.train_step_context():
        return self._step(stage, *args, **kwargs)

def validation_step(self, *args, **kwargs):
    """Dispatch the IPU ``_step`` for the validation stage."""
    stage = RunningStage.VALIDATING
    return self._step(stage, *args, **kwargs)
Expand Down
3 changes: 0 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,6 @@ def start_predicting(self, trainer: "pl.Trainer") -> Any:
self._clean_logger(trainer)
return super().start_predicting(trainer)

def training_step(self, *args, **kwargs):
    """Delegate the training step to the wrapped model."""
    output = self.model(*args, **kwargs)
    return output

def validation_step(self, *args, **kwargs):
    """Delegate the validation step to the wrapped model."""
    result = self.model(*args, **kwargs)
    return result

Expand Down
11 changes: 8 additions & 3 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.types import _PATH
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

TBroadcast = TypeVar("TBroadcast")

Expand Down Expand Up @@ -313,8 +313,13 @@ def start_predicting(self, trainer: "pl.Trainer") -> Any:
# double dispatch to initiate the predicting loop
return trainer.run_stage()

def training_step(self, *args, **kwargs):
return self.model.training_step(*args, **kwargs)
def training_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    """The actual training step.

    See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
    """
    step_fn = self.model.training_step
    with self.precision_plugin.train_step_context():
        return step_fn(*args, **kwargs)

def post_training_step(self):
pass
Expand Down

0 comments on commit 1f4df1e

Please sign in to comment.