Commit

Merge branch 'release/1.2-dev' into fix/yaml
mergify[bot] authored Jan 24, 2021
2 parents ce298e7 + 6386f45 commit b8eda67
Showing 5 changed files with 46 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/sharded_plugin.py
@@ -57,7 +57,7 @@ def _optim_state_dict(self, optimizer):

     def _wrap_optimizers(self, model):
         trainer = model.trainer
-        if trainer.testing is True:
+        if trainer.testing:
             return

         self._reinit_with_fairscale_oss(trainer)
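The one-line change above drops an explicit `is True` comparison in favour of a plain truthiness check. A small illustrative sketch, not part of this commit (the `RunningStage` and `testing` here are stand-ins for the real pytorch_lightning definitions), showing why the two are equivalent once `testing` is computed as an enum comparison:

from enum import Enum


class RunningStage(Enum):
    # Illustrative stand-in for pytorch_lightning.trainer.states.RunningStage.
    TESTING = "test"


def testing(running_stage) -> bool:
    """Mimics the `Trainer.testing` property: an enum equality check returns a plain bool."""
    return running_stage == RunningStage.TESTING


# Because the property already yields a bool, truthiness is enough;
# the extra `is True` comparison adds nothing.
assert testing(RunningStage.TESTING)
assert not testing(None)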
24 changes: 0 additions & 24 deletions pytorch_lightning/trainer/deprecated_api.py
@@ -130,27 +130,3 @@ def use_single_gpu(self, val: bool) -> None:
         )
         if val:
             self._device_type = DeviceType.GPU
-
-    @property
-    def training(self) -> bool:
-        # todo: consider rename as `is_training`
-        return self._running_stage == RunningStage.TRAINING
-
-    @training.setter
-    def training(self, val: bool) -> None:
-        if val:
-            self._running_stage = RunningStage.TRAINING
-        else:
-            self._running_stage = None
-
-    @property
-    def testing(self) -> bool:
-        # todo: consider rename as `is_testing`
-        return self._running_stage == RunningStage.TESTING
-
-    @testing.setter
-    def testing(self, val: bool) -> None:
-        if val:
-            self._running_stage = RunningStage.TESTING
-        else:
-            self._running_stage = None
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -38,7 +38,6 @@ def on_trainer_init(self):
         self.trainer.test_dataloaders = None
         self.trainer.val_dataloaders = None
         self.trainer.running_sanity_check = False
-        self.trainer.testing = False

         # when .test() is called, it sets this
         self.trainer.tested_ckpt_path = None
46 changes: 45 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -53,7 +53,7 @@
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.properties import TrainerProperties
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import RunningStage, TrainerState
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
@@ -921,3 +921,47 @@ def available_plugins():
         Returns: List of all available plugins that are supported as string arguments.
         """
         return PluginConnector.available_plugins()
+
+    @property
+    def training(self) -> bool:
+        return self._running_stage == RunningStage.TRAINING
+
+    @training.setter
+    def training(self, val: bool) -> None:
+        if val:
+            self._running_stage = RunningStage.TRAINING
+        elif self.training:
+            self._running_stage = None
+
+    @property
+    def testing(self) -> bool:
+        return self._running_stage == RunningStage.TESTING
+
+    @testing.setter
+    def testing(self, val: bool) -> None:
+        if val:
+            self._running_stage = RunningStage.TESTING
+        elif self.testing:
+            self._running_stage = None
+
+    @property
+    def tuning(self) -> bool:
+        return self._running_stage == RunningStage.TUNING
+
+    @tuning.setter
+    def tuning(self, val: bool) -> None:
+        if val:
+            self._running_stage = RunningStage.TUNING
+        elif self.tuning:
+            self._running_stage = None
+
+    @property
+    def evaluating(self) -> bool:
+        return self._running_stage == RunningStage.EVALUATING
+
+    @evaluating.setter
+    def evaluating(self, val: bool) -> None:
+        if val:
+            self._running_stage = RunningStage.EVALUATING
+        elif self.evaluating:
+            self._running_stage = None
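Read together with the deletions in deprecated_api.py above, this hunk moves the `training`/`testing` properties onto trainer.py, adds `tuning` and `evaluating`, and backs all four with a single `_running_stage`. The setters also gain an `elif self.<stage>:` guard, so assigning `False` only clears the stage that setter owns. A minimal, self-contained sketch of that pattern (a bare class with illustrative `RunningStage` values, not the real Trainer) to show the guard's effect:

from enum import Enum
from typing import Optional


class RunningStage(Enum):
    # Illustrative values; the real enum lives in pytorch_lightning.trainer.states.
    TRAINING = "train"
    TESTING = "test"


class StageOwner:
    """Hypothetical object mirroring the boolean-over-enum pattern in the hunk above."""

    def __init__(self) -> None:
        self._running_stage: Optional[RunningStage] = None

    @property
    def training(self) -> bool:
        return self._running_stage == RunningStage.TRAINING

    @training.setter
    def training(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.TRAINING
        elif self.training:
            # Only clear the stage if we actually are training; a stray
            # `obj.training = False` no longer wipes out another stage.
            self._running_stage = None

    @property
    def testing(self) -> bool:
        return self._running_stage == RunningStage.TESTING

    @testing.setter
    def testing(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.TESTING
        elif self.testing:
            self._running_stage = None


owner = StageOwner()
owner.training = True
owner.testing = False      # guard kicks in: not testing, so nothing is cleared
assert owner.training      # still in the TRAINING stage
owner.training = False     # we are training, so the stage is cleared
assert not owner.training and not owner.testing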
16 changes: 0 additions & 16 deletions tests/deprecated_api/test_remove_1-4.py
@@ -89,22 +89,6 @@ def test_v1_4_0_deprecated_trainer_device_distrib():
     assert trainer.use_horovod


-def test_v1_4_0_deprecated_trainer_phase():
-    """Test that Trainer attributes works fine."""
-    trainer = Trainer()
-
-    assert not trainer.training
-    assert not trainer.testing
-
-    trainer.training = True
-    assert trainer.training
-    assert not trainer.testing
-
-    trainer.testing = True
-    assert not trainer.training
-    assert trainer.testing
-
-
 def test_v1_4_0_deprecated_metrics():
     from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
     with pytest.deprecated_call(match='will be removed in v1.4'):
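The deleted `test_v1_4_0_deprecated_trainer_phase` covered the old deprecated-API setters. A hedged sketch of how equivalent coverage for the stage-backed properties added in trainer.py might look (hypothetical test name, not part of this commit; it assumes a freshly constructed `Trainer` reports no active stage, as the deleted test already relied on):

from pytorch_lightning import Trainer


def test_trainer_stage_properties():  # hypothetical test, not part of this commit
    trainer = Trainer()

    # No stage is active right after construction.
    assert not trainer.training
    assert not trainer.testing
    assert not trainer.tuning
    assert not trainer.evaluating

    # Setting one stage replaces whatever stage was active before.
    trainer.training = True
    assert trainer.training and not trainer.testing
    trainer.testing = True
    assert trainer.testing and not trainer.training

    # Setting an *inactive* stage to False is a no-op thanks to the `elif` guard.
    trainer.training = False
    assert trainer.testing

    # Setting the *active* stage to False clears it.
    trainer.testing = False
    assert not trainer.testing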
