Skip to content

Commit

Permalink
mini refactor for _running_stage access (#5724)
Browse files Browse the repository at this point in the history
* running stage

* circular import

* running stage cleanup

* fix unused import

* fix running stage access

* add return type

* Revert "add return type"

This reverts commit 65b0fe2.

* try fix typing
  • Loading branch information
awaelchli authored Feb 22, 2021
1 parent 423ecf9 commit 0456b45
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 27 deletions.
10 changes: 8 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch import ScriptModule, Tensor
Expand All @@ -44,6 +44,9 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args

if TYPE_CHECKING:
from pytorch_lightning.trainer.states import RunningStage


class LightningModule(
ABC,
Expand Down Expand Up @@ -103,7 +106,6 @@ def __init__(self, *args, **kwargs):
self._running_manual_backward = False
self._current_hook_fx_name = None
self._current_dataloader_idx = None
self.running_stage = None
self._automatic_optimization: bool = True

def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
Expand Down Expand Up @@ -169,6 +171,10 @@ def automatic_optimization(self) -> bool:
"""
return self._automatic_optimization

@property
def running_stage(self) -> Optional["RunningStage"]:
return self.trainer._running_stage if self.trainer else None

@automatic_optimization.setter
def automatic_optimization(self, automatic_optimization: bool) -> None:
self._automatic_optimization = automatic_optimization
Expand Down
27 changes: 8 additions & 19 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.model_helpers import is_overridden
Expand Down Expand Up @@ -450,7 +449,7 @@ def fit(
# bookkeeping
# we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
if self._running_stage is None:
self._set_running_stage(RunningStage.TRAINING, model)
self._running_stage = RunningStage.TRAINING

# set local properties on the model
self.model_connector.copy_trainer_model_properties(model)
Expand Down Expand Up @@ -531,7 +530,7 @@ def fit(
if self._state != TrainerState.INTERRUPTED:
self._state = TrainerState.FINISHED

self._set_running_stage(None, model)
self._running_stage = None

return self.accelerator.results or 1

Expand Down Expand Up @@ -564,14 +563,6 @@ def train_or_test_or_predict(self):

return results

def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):
"""
This function is used to set the running_state on both
the trainer and the model
"""
model_ref.running_stage = stage
self._running_stage = stage

def _pre_training_routine(self):
# wait for all to join if on distributed
self.accelerator.barrier("setup_training")
Expand Down Expand Up @@ -614,7 +605,7 @@ def run_train(self):
self.run_sanity_check(self.lightning_module)

# set stage for logging
self._set_running_stage(RunningStage.TRAINING, self.lightning_module)
self._running_stage = RunningStage.TRAINING

self.checkpoint_connector.has_trained = False

Expand Down Expand Up @@ -678,9 +669,7 @@ def run_train(self):
def run_evaluation(self, max_batches=None, on_epoch=False):

# used to know if we are logging for val, test + reset cached results
self._set_running_stage(
RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module
)
self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING
self.logger_connector.reset()

# bookkeeping
Expand Down Expand Up @@ -907,7 +896,7 @@ def test(
# --------------------
self.verbose_test = verbose

self._set_running_stage(RunningStage.TESTING, model or self.lightning_module)
self._running_stage = RunningStage.TESTING

# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if test_dataloaders and datamodule:
Expand All @@ -924,7 +913,7 @@ def test(
results = self.__test_using_best_weights(ckpt_path, test_dataloaders)

self.teardown('test')
self._set_running_stage(None, model or self.lightning_module)
self._running_stage = None
return results

def __test_using_best_weights(self, ckpt_path, test_dataloaders):
Expand Down Expand Up @@ -1016,7 +1005,7 @@ def predict(

model = model or self.lightning_module

self._set_running_stage(RunningStage.PREDICTING, model)
self._running_stage = RunningStage.PREDICTING

if dataloaders and datamodule:
raise MisconfigurationException(
Expand All @@ -1033,7 +1022,7 @@ def predict(

self.model = model
results = self.fit(model)
self._set_running_stage(None, model)
self._running_stage = None

return results

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def run_training_epoch(self):
self.trainer.run_evaluation()

# reset stage to train
self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
self.trainer._running_stage = RunningStage.TRAINING

# -----------------------------------------
# SAVE LOGGERS (ie: Tensorboard, etc...)
Expand Down Expand Up @@ -564,7 +564,7 @@ def run_training_epoch(self):
self.trainer.run_evaluation(on_epoch=True)

# reset stage to train
self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
self.trainer._running_stage = RunningStage.TRAINING

should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def on_train_start(self):
# haven't trained with the new loaded model
dp_model = new_trainer.model
dp_model.eval()
dp_model.module.module.running_stage = RunningStage.EVALUATING
new_trainer._running_stage = RunningStage.EVALUATING

dataloader = self.train_dataloader()
tpipes.run_prediction(self.trainer.lightning_module, dataloader)
Expand Down
8 changes: 5 additions & 3 deletions tests/overrides/test_data_parallel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock

import pytest
import torch
Expand Down Expand Up @@ -103,7 +103,8 @@ def training_step(self, batch, batch_idx):
return {"loss": loss}

model = TestModel()
model.running_stage = RunningStage.TRAINING
model.trainer = Mock()
model.trainer._running_stage = RunningStage.TRAINING
batch = torch.rand(2, 32).cuda()
batch_idx = 0

Expand Down Expand Up @@ -146,7 +147,8 @@ def training_step(self, batch, batch_idx):

model = TestModel()
model.to(device)
model.running_stage = RunningStage.TRAINING
model.trainer = Mock()
model.trainer._running_stage = RunningStage.TRAINING
batch = torch.rand(2, 32).to(device)
batch_idx = 0

Expand Down

0 comments on commit 0456b45

Please sign in to comment.