From 0456b4598f5f7eaebf626bca45d563562a15887b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 22 Feb 2021 12:01:54 +0100 Subject: [PATCH] mini refactor for _running_stage access (#5724) * running stage * circular import * running stage cleanup * fix unused import * fix running stage access * add return type * Revert "add return type" This reverts commit 65b0fe269c6547213e34b6a88b97bee31cdfe8c7. * try fix typing --- pytorch_lightning/core/lightning.py | 10 ++++++-- pytorch_lightning/trainer/trainer.py | 27 +++++++--------------- pytorch_lightning/trainer/training_loop.py | 4 ++-- tests/models/test_restore.py | 2 +- tests/overrides/test_data_parallel.py | 8 ++++--- 5 files changed, 24 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 78ee40d81d45a..57aa264244a68 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -24,7 +24,7 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch from torch import ScriptModule, Tensor @@ -44,6 +44,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args +if TYPE_CHECKING: + from pytorch_lightning.trainer.states import RunningStage + class LightningModule( ABC, @@ -103,7 +106,6 @@ def __init__(self, *args, **kwargs): self._running_manual_backward = False self._current_hook_fx_name = None self._current_dataloader_idx = None - self.running_stage = None self._automatic_optimization: bool = True def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: @@ -169,6 +171,10 @@ def automatic_optimization(self) -> bool: """ return self._automatic_optimization + @property + def running_stage(self) -> Optional["RunningStage"]: + return self.trainer._running_stage if self.trainer else None + @automatic_optimization.setter def automatic_optimization(self, automatic_optimization: bool) -> None: self._automatic_optimization = automatic_optimization diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 10545a075cb32..cf3bfd7a3e5a3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -59,7 +59,6 @@ from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger -from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden @@ -450,7 +449,7 @@ def fit( # bookkeeping # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified. if self._running_stage is None: - self._set_running_stage(RunningStage.TRAINING, model) + self._running_stage = RunningStage.TRAINING # set local properties on the model self.model_connector.copy_trainer_model_properties(model) @@ -531,7 +530,7 @@ def fit( if self._state != TrainerState.INTERRUPTED: self._state = TrainerState.FINISHED - self._set_running_stage(None, model) + self._running_stage = None return self.accelerator.results or 1 @@ -564,14 +563,6 @@ def train_or_test_or_predict(self): return results - def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule): - """ - This function is used to set the running_state on both - the trainer and the model - """ - model_ref.running_stage = stage - self._running_stage = stage - def _pre_training_routine(self): # wait for all to join if on distributed self.accelerator.barrier("setup_training") @@ -614,7 +605,7 @@ def run_train(self): self.run_sanity_check(self.lightning_module) # set stage for logging - self._set_running_stage(RunningStage.TRAINING, self.lightning_module) + self._running_stage = RunningStage.TRAINING self.checkpoint_connector.has_trained = False @@ -678,9 +669,7 @@ def run_train(self): def run_evaluation(self, max_batches=None, on_epoch=False): # used to know if we are logging for val, test + reset cached results - self._set_running_stage( - RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module - ) + self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING self.logger_connector.reset() # bookkeeping @@ -907,7 +896,7 @@ def test( # -------------------- self.verbose_test = verbose - self._set_running_stage(RunningStage.TESTING, model or self.lightning_module) + self._running_stage = RunningStage.TESTING # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: @@ -924,7 +913,7 @@ def test( results = self.__test_using_best_weights(ckpt_path, test_dataloaders) self.teardown('test') - self._set_running_stage(None, model or self.lightning_module) + self._running_stage = None return results def __test_using_best_weights(self, ckpt_path, test_dataloaders): @@ -1016,7 +1005,7 @@ def predict( model = model or self.lightning_module - self._set_running_stage(RunningStage.PREDICTING, model) + self._running_stage = RunningStage.PREDICTING if dataloaders and datamodule: raise MisconfigurationException( @@ -1033,7 +1022,7 @@ def predict( self.model = model results = self.fit(model) - self._set_running_stage(None, model) + self._running_stage = None return results diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9d10a1f67c5dc..d2298c8c4e860 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -517,7 +517,7 @@ def run_training_epoch(self): self.trainer.run_evaluation() # reset stage to train - self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module) + self.trainer._running_stage = RunningStage.TRAINING # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) @@ -564,7 +564,7 @@ def run_training_epoch(self): self.trainer.run_evaluation(on_epoch=True) # reset stage to train - self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module) + self.trainer._running_stage = RunningStage.TRAINING should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index d28ab6177f21c..a3f88e37bb09a 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -453,7 +453,7 @@ def on_train_start(self): # haven't trained with the new loaded model dp_model = new_trainer.model dp_model.eval() - dp_model.module.module.running_stage = RunningStage.EVALUATING + new_trainer._running_stage = RunningStage.EVALUATING dataloader = self.train_dataloader() tpipes.run_prediction(self.trainer.lightning_module, dataloader) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 64481bd70390d..90bb6fac88457 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest import torch @@ -103,7 +103,8 @@ def training_step(self, batch, batch_idx): return {"loss": loss} model = TestModel() - model.running_stage = RunningStage.TRAINING + model.trainer = Mock() + model.trainer._running_stage = RunningStage.TRAINING batch = torch.rand(2, 32).cuda() batch_idx = 0 @@ -146,7 +147,8 @@ def training_step(self, batch, batch_idx): model = TestModel() model.to(device) - model.running_stage = RunningStage.TRAINING + model.trainer = Mock() + model.trainer._running_stage = RunningStage.TRAINING batch = torch.rand(2, 32).to(device) batch_idx = 0