simplify training phase as Enum #5419

Merged
merged 17 commits on Jan 13, 2021
27 changes: 26 additions & 1 deletion pytorch_lightning/trainer/deprecated_api.py
@@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import DistributedType, DeviceType, rank_zero_warn


class DeprecatedDistDeviceAttributes:

    _distrib_type: DistributedType
    _device_type: DeviceType
    _running_stage: RunningStage
    num_gpus: int

    @property
@@ -129,3 +130,27 @@ def use_single_gpu(self, val: bool) -> None:
        )
        if val:
            self._device_type = DeviceType.GPU

    @property
    def training(self) -> bool:
        # todo: consider renaming to `is_training`
        return self._running_stage == RunningStage.TRAINING

    @training.setter
    def training(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.TRAINING
        else:
            self._running_stage = None

    @property
    def testing(self) -> bool:
        # todo: consider renaming to `is_testing`
        return self._running_stage == RunningStage.TESTING

    @testing.setter
    def testing(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.TESTING
        else:
            self._running_stage = None
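
Taken together, the properties above keep the old boolean API working while the actual state lives in a single RunningStage field. A minimal usage sketch (illustrative only, not part of the diff; it pokes at the private `_running_stage` attribute added in this PR):

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage

trainer = Trainer()

# The deprecated boolean flags are thin wrappers around one enum-valued field.
trainer.training = True
assert trainer._running_stage == RunningStage.TRAINING

# Switching to testing implicitly clears training, since both properties read the same field.
trainer.testing = True
assert trainer.testing and not trainer.training
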
18 changes: 15 additions & 3 deletions pytorch_lightning/trainer/states.py
@@ -17,25 +17,37 @@
from typing import Callable, Optional

import pytorch_lightning
from pytorch_lightning.utilities import LightningEnum


-class TrainerState(str, Enum):
+class TrainerState(LightningEnum):
""" State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
to indicate what is currently or was executed.

>>> # you can math the type with string
Borda marked this conversation as resolved.
Show resolved Hide resolved
>>> TrainerState.RUNNING == 'RUNNING'
True
>>> # which is case sensitive
>>> # which is case insensitive
>>> TrainerState.FINISHED == 'finished'
False
True
"""
INITIALIZING = 'INITIALIZING'
RUNNING = 'RUNNING'
FINISHED = 'FINISHED'
INTERRUPTED = 'INTERRUPTED'


class RunningStage(LightningEnum):
    """Type of the train phase.

    >>> # you can match the Enum with a string
    >>> RunningStage.TRAINING == 'train'
    True
    """
    TRAINING = 'train'
    TESTING = 'test'


def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
""" Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods
which changes state to `entering` before the function execution and `exiting`
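
The updated doctests imply that LightningEnum members compare equal to plain strings by value and ignore case. The real implementation lives in `pytorch_lightning.utilities`; the snippet below is only a hypothetical sketch of how such an enum could behave, not the library's code:

from enum import Enum

class CaseInsensitiveEnum(str, Enum):
    """Illustrative stand-in for LightningEnum (hypothetical implementation)."""

    def __eq__(self, other: object) -> bool:
        # Compare against plain strings by value, ignoring case.
        if isinstance(other, str):
            return self.value.lower() == other.lower()
        return super().__eq__(other)

    def __hash__(self) -> int:
        # Defining __eq__ disables the inherited hash, so restore one explicitly.
        return hash(self.value.lower())

class TrainerStateDemo(CaseInsensitiveEnum):
    FINISHED = 'FINISHED'

assert TrainerStateDemo.FINISHED == 'FINISHED'
assert TrainerStateDemo.FINISHED == 'finished'  # matches the updated doctest above
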
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -294,6 +294,7 @@ def __init__(
        super().__init__()
        self._device_type = DeviceType.CPU
        self._distrib_type = None
        self._running_stage = None

        # init connectors
        self.dev_debugger = InternalDebugger(self)
18 changes: 17 additions & 1 deletion tests/deprecated_api/test_remove_1-4.py
@@ -37,7 +37,7 @@ def test_v1_4_0_deprecated_imports():
        from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils  # noqa: F811 F401


-def test_v1_4_0_deprecated_trainer_attributes():
+def test_v1_4_0_deprecated_trainer_device_distrib():
"""Test that Trainer attributes works fine."""
trainer = Trainer()
trainer._distrib_type = None
@@ -76,6 +76,22 @@ def test_v1_4_0_deprecated_trainer_attributes():
    assert trainer.use_horovod


def test_v1_4_0_deprecated_trainer_phase():
    """Test that Trainer attributes work fine."""
    trainer = Trainer()

    assert not trainer.training
    assert not trainer.testing

    trainer.training = True
    assert trainer.training
    assert not trainer.testing

    trainer.testing = True
    assert not trainer.training
    assert trainer.testing


def test_v1_4_0_deprecated_metrics():
    from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
    with pytest.deprecated_call(match='will be removed in v1.4'):
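
As a closing illustration, downstream code can now branch on the stage enum directly instead of juggling separate booleans. A short usage sketch, based only on the doctests in states.py above:

from pytorch_lightning.trainer.states import RunningStage, TrainerState

# Enum members compare equal to their string values.
assert RunningStage.TRAINING == 'train'
assert TrainerState.FINISHED == 'finished'  # case-insensitive, per the updated docstring

# Branching on the current phase with the enum rather than separate booleans.
stage = RunningStage.TESTING
if stage == RunningStage.TRAINING:
    print("running the training loop")
elif stage == RunningStage.TESTING:
    print("running the test loop")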