Disable checkpointing, earlystopping and logging with fast_dev_run #5277

Merged (12 commits) on Jan 5, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


### Fixed
- Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))



7 changes: 6 additions & 1 deletion docs/source/debugging.rst
@@ -28,13 +28,18 @@ The point is to detect any bugs in the training/validation loop without having t
argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)

.. testcode::

# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)

# runs 7 train, val, test batches and program ends
trainer = Trainer(fast_dev_run=7)

.. note::

This argument will disable tuner, checkpoint callbacks, early stopping callbacks,
loggers and logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch.

----------------

Inspect gradient norms
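
For illustration (not part of this PR's diff), a minimal sketch of what the updated note means in practice, assuming the changes in this PR are applied:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers.base import DummyLogger

# runs 1 train, val and test batch and then exits; the tuner, checkpoint
# callbacks, early stopping callbacks and loggers are all inert in this mode
trainer = Trainer(fast_dev_run=True)

# the real logger is replaced by a no-op DummyLogger when the Trainer is built
assert isinstance(trainer.logger, DummyLogger)
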
6 changes: 3 additions & 3 deletions docs/source/trainer.rst
@@ -666,9 +666,9 @@ Under the hood the pseudocode looks like this when running *fast_dev_run* with a
.. note::

This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument will
disable tuner, logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch. This must be
used only for debugging purposes. ``limit_train/val/test_batches`` only limits the number of batches and won't
disable anything.
disable tuner, checkpoint callbacks, early stopping callbacks, loggers and logger callbacks like
``LearningRateLogger`` and runs for only 1 epoch. This must be used only for debugging purposes.
``limit_train/val/test_batches`` only limits the number of batches and won't disable anything.

gpus
^^^^
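
To make the distinction in this note concrete, a short sketch (not part of the diff; both trainers are otherwise default-configured):

from pytorch_lightning import Trainer

# fast_dev_run=7: 7 train/val/test batches, one epoch, no sanity check; the tuner,
# checkpointing, early stopping and loggers are all disabled, so this is for debugging only
debug_trainer = Trainer(fast_dev_run=7)

# limit_*_batches=7: only caps the number of batches per loop; checkpointing,
# early stopping, loggers and max_epochs all behave exactly as configured
limited_trainer = Trainer(limit_train_batches=7, limit_val_batches=7, limit_test_batches=7)
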
23 changes: 9 additions & 14 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -28,7 +28,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, TPU_AVAILABLE


class EarlyStopping(Callback):
@@ -166,10 +166,10 @@ def on_validation_end(self, trainer, pl_module):
self._run_early_stopping_check(trainer, pl_module)

def on_validation_epoch_end(self, trainer, pl_module):
if trainer.running_sanity_check:
if trainer.fast_dev_run or trainer.running_sanity_check:
return

if self._validate_condition_metric(trainer.logger_connector.callback_metrics):
if self._validate_condition_metric(trainer.callback_metrics):
# turn off early stopping in on_train_epoch_end
self.based_on_eval_results = True

@@ -178,24 +178,19 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
if self.based_on_eval_results:
return

# early stopping can also work in the train loop when there is no val loop
should_check_early_stop = False

# fallback to monitor key in result dict
if trainer.logger_connector.callback_metrics.get(self.monitor, None) is not None:
should_check_early_stop = True

if should_check_early_stop:
self._run_early_stopping_check(trainer, pl_module)
self._run_early_stopping_check(trainer, pl_module)

def _run_early_stopping_check(self, trainer, pl_module):
"""
Checks whether the early stopping condition is met
and if so tells the trainer to stop the training.
"""
logs = trainer.logger_connector.callback_metrics
logs = trainer.callback_metrics

if not self._validate_condition_metric(logs):
if (
trainer.fast_dev_run # disable early_stopping with fast_dev_run
or not self._validate_condition_metric(logs) # short circuit if metric not present
):
return # short circuit if metric not present

current = logs.get(self.monitor)
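
The net effect of the early-stopping changes above can be summarised in a small, self-contained sketch (simplified; the real check lives in EarlyStopping._run_early_stopping_check and also warns when the monitored metric is missing):

def should_run_early_stopping_check(fast_dev_run: bool, callback_metrics: dict, monitor: str) -> bool:
    # early stopping is skipped entirely in fast_dev_run mode
    if fast_dev_run:
        return False
    # short circuit if the monitored metric is not present
    return monitor in callback_metrics

assert should_run_early_stopping_check(False, {'val_loss': 0.3}, 'val_loss')
assert not should_run_early_stopping_check(True, {'val_loss': 0.3}, 'val_loss')
assert not should_run_early_stopping_check(False, {}, 'val_loss')
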
3 changes: 1 addition & 2 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -24,7 +24,7 @@
import shutil
import subprocess
import time
from typing import List, Tuple, Dict
from typing import Dict, List, Tuple

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_only
@@ -213,5 +213,4 @@ def _should_log(trainer) -> bool:
or trainer.should_stop
)

should_log = should_log and not trainer.fast_dev_run
return should_log
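
The _should_log simplification drops the old `and not trainer.fast_dev_run` mask because fast_dev_run now routes all logging to a no-op logger anyway. A simplified sketch of the remaining decision (parameter names are illustrative, not the exact Trainer attributes):

def should_log(global_step: int, log_every_n_steps: int, should_stop: bool) -> bool:
    # log on the configured step interval, or when training is about to stop
    return (global_step + 1) % log_every_n_steps == 0 or should_stop

assert should_log(global_step=99, log_every_n_steps=50, should_stop=False)
assert not should_log(global_step=100, log_every_n_steps=50, should_stop=False)
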
5 changes: 2 additions & 3 deletions pytorch_lightning/callbacks/lr_monitor.py
@@ -105,15 +105,15 @@ def on_train_batch_start(self, trainer, *args, **kwargs):
interval = 'step' if self.logging_interval is None else 'any'
latest_stat = self._extract_stats(trainer, interval)

if trainer.logger is not None and latest_stat:
if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

def on_train_epoch_start(self, trainer, *args, **kwargs):
if self.logging_interval != 'step':
interval = 'epoch' if self.logging_interval is None else 'any'
latest_stat = self._extract_stats(trainer, interval)

if trainer.logger is not None and latest_stat:
if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
@@ -190,5 +190,4 @@ def _should_log(trainer) -> bool:
or trainer.should_stop
)

should_log = should_log and not trainer.fast_dev_run
return should_log
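
Dropping the `trainer.logger is not None` guard here is safe because the callback already raises at train start if no logger is configured, and with this PR fast_dev_run substitutes a DummyLogger whose calls are no-ops. A tiny sketch of that behaviour (the metric names are made up for illustration):

from pytorch_lightning.loggers.base import DummyLogger

logger = DummyLogger()

# every call is accepted and silently discarded, so callbacks can log
# unconditionally during a fast_dev_run
logger.log_metrics({'lr-Adam': 1e-3}, step=0)
logger.log_hyperparams({'fast_dev_run': True})
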
7 changes: 4 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -224,7 +224,8 @@ def save_checkpoint(self, trainer, pl_module):
global_step = trainer.global_step

if (
self.save_top_k == 0 # no models are saved
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or self.save_top_k == 0 # no models are saved
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
@@ -478,14 +479,14 @@ def __resolve_ckpt_dir(self, trainer, pl_module):
version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name))

ckpt_path = os.path.join(
save_dir, name, version, "checkpoints"
save_dir, str(name), version, "checkpoints"
)
else:
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

self.dirpath = ckpt_path

if trainer.is_global_zero:
if not trainer.fast_dev_run and trainer.is_global_zero:
self._fs.makedirs(self.dirpath, exist_ok=True)

def _add_backward_monitor_support(self, trainer):
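
A self-contained sketch of the skip conditions in save_checkpoint after this change (simplified; the real method also de-duplicates on global_step and validates the monitored key):

def should_skip_save(fast_dev_run, save_top_k, period, epoch, running_sanity_check):
    return (
        fast_dev_run                   # checkpointing disabled with fast_dev_run
        or save_top_k == 0             # no models are saved
        or period < 1                  # no models are saved
        or (epoch + 1) % period != 0   # skip this epoch
        or running_sanity_check        # never save during the sanity check
    )

assert should_skip_save(True, save_top_k=1, period=1, epoch=0, running_sanity_check=False)
assert not should_skip_save(False, save_top_k=1, period=1, epoch=0, running_sanity_check=False)
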
3 changes: 1 addition & 2 deletions pytorch_lightning/callbacks/progress.py
@@ -22,7 +22,6 @@
import importlib
import sys


# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
if importlib.util.find_spec('ipywidgets') is not None:
@@ -323,7 +322,7 @@ def on_epoch_start(self, trainer, pl_module):
super().on_epoch_start(trainer, pl_module)
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float('inf') and not trainer.fast_dev_run:
if total_train_batches != float('inf'):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
13 changes: 10 additions & 3 deletions pytorch_lightning/trainer/connectors/debugging_connector.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from typing import Union
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info

from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class DebuggingConnector:
@@ -54,11 +56,16 @@ def on_init_start(
limit_train_batches = fast_dev_run
limit_val_batches = fast_dev_run
limit_test_batches = fast_dev_run
self.trainer.max_steps = fast_dev_run
self.trainer.num_sanity_val_steps = 0
self.trainer.max_epochs = 1
self.trainer.val_check_interval = 1.0
self.trainer.check_val_every_n_epoch = 1
self.trainer.logger = DummyLogger()

rank_zero_info(
'Running in fast_dev_run mode: will run a full train,'
f' val and test loop using {fast_dev_run} batch(es)'
f' val and test loop using {fast_dev_run} batch(es).'
)

self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
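
Taken together, the added lines mean that fast_dev_run=n pins the run to a single epoch of n batches with no sanity check and a no-op logger. A sketch of the configuration these assignments are expected to produce once the PR is applied (not an exhaustive list):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers.base import DummyLogger

trainer = Trainer(fast_dev_run=7)

assert trainer.limit_train_batches == 7
assert trainer.limit_val_batches == 7
assert trainer.limit_test_batches == 7
assert trainer.max_steps == 7
assert trainer.max_epochs == 1
assert trainer.num_sanity_val_steps == 0
assert isinstance(trainer.logger, DummyLogger)
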
38 changes: 29 additions & 9 deletions pytorch_lightning/trainer/properties.py
@@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Type, TypeVar, Union, cast
import inspect
import os
from typing import cast, List, Optional, Type, TypeVar, Union

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.loggers.base import LightningLoggerBase
@@ -27,7 +27,7 @@
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import HOROVOD_AVAILABLE, TPU_AVAILABLE, argparse_utils, rank_zero_warn
from pytorch_lightning.utilities import argparse_utils, HOROVOD_AVAILABLE, rank_zero_warn, TPU_AVAILABLE
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_utils import is_overridden

@@ -196,7 +196,7 @@ def enable_validation(self) -> bool:
""" Check if we should run validation during training. """
model_ref = self.model_connector.get_model()
val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0
return val_loop_enabled or self.fast_dev_run
return val_loop_enabled

@property
def default_root_dir(self) -> str:
@@ -218,18 +218,38 @@ def weights_save_path(self) -> str:
return os.path.normpath(self._weights_save_path)
return self._weights_save_path

@property
def early_stopping_callback(self) -> Optional[EarlyStopping]:
"""
The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.
"""
callbacks = self.early_stopping_callbacks
return callbacks[0] if len(callbacks) > 0 else None

@property
def early_stopping_callbacks(self) -> List[EarlyStopping]:
"""
A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
found in the Trainer.callbacks list.
"""
return [c for c in self.callbacks if isinstance(c, EarlyStopping)]

@property
def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
"""
The first checkpoint callback in the Trainer.callbacks list, or ``None`` if
no checkpoint callbacks exist.
The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.
"""
callbacks = self.checkpoint_callbacks
return callbacks[0] if len(callbacks) > 0 else None

@property
def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
""" A list of all instances of ModelCheckpoint found in the Trainer.callbacks list. """
"""
A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
found in the Trainer.callbacks list.
"""
return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]

def save_checkpoint(self, filepath, weights_only: bool = False):
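
The new early_stopping_callback/early_stopping_callbacks properties mirror the existing checkpoint ones. A short usage sketch (assumes no other EarlyStopping or ModelCheckpoint callbacks are configured; the monitored key is arbitrary):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

es = EarlyStopping(monitor='val_loss')
ckpt = ModelCheckpoint(monitor='val_loss')
trainer = Trainer(callbacks=[es, ckpt])

assert trainer.early_stopping_callback is es
assert trainer.early_stopping_callbacks == [es]
assert trainer.checkpoint_callback is ckpt
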
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -915,9 +915,8 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
def save_loggers_on_train_batch_end(self):
# when loggers should save to disk
should_flush_logs = self.trainer.logger_connector.should_flush_logs
if should_flush_logs or self.trainer.fast_dev_run is True:
if self.trainer.is_global_zero and self.trainer.logger is not None:
self.trainer.logger.save()
if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
self.trainer.logger.save()

def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator):
"""
12 changes: 7 additions & 5 deletions tests/callbacks/test_early_stopping.py
@@ -13,18 +13,17 @@
# limitations under the License.
import os
import pickle
from unittest import mock

import cloudpickle
import numpy as np
import pytest
import torch
from unittest import mock

from pytorch_lightning import _logger
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import _logger, seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from tests.base import EvalModelTemplate, BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel, EvalModelTemplate


class EarlyStoppingTestRestore(EarlyStopping):
@@ -87,15 +86,18 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
def test_early_stopping_no_extraneous_invocations(tmpdir):
"""Test to ensure that callback methods aren't being invoked outside of the callback handler."""
model = EvalModelTemplate()
early_stop_callback = EarlyStopping()
expected_count = 4
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[EarlyStopping()],
callbacks=[early_stop_callback],
val_check_interval=1.0,
max_epochs=expected_count,
)
trainer.fit(model)

assert trainer.early_stopping_callback == early_stop_callback
assert trainer.early_stopping_callbacks == [early_stop_callback]
assert len(trainer.dev_debugger.early_stopping_history) == expected_count


47 changes: 1 addition & 46 deletions tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -17,55 +17,10 @@
import pytest
import torch

from pytorch_lightning import Trainer, callbacks, seed_everything
from pytorch_lightning import callbacks, seed_everything, Trainer
from tests.base import BoringModel


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_mc_called_on_fastdevrun(tmpdir):
seed_everything(1234)

train_val_step_model = BoringModel()

# fast dev run = called once
# train loop only, dict, eval result
trainer = Trainer(fast_dev_run=True)
trainer.fit(train_val_step_model)

# checkpoint should have been called once with fast dev run
assert len(trainer.dev_debugger.checkpoint_callback_history) == 1

# -----------------------
# also called once with no val step
# -----------------------
class TrainingStepCalled(BoringModel):
def __init__(self):
super().__init__()
self.training_step_called = False
self.validation_step_called = False
self.test_step_called = False

def training_step(self, batch, batch_idx):
self.training_step_called = True
return super().training_step(batch, batch_idx)

train_step_only_model = TrainingStepCalled()
train_step_only_model.validation_step = None

# fast dev run = called once
# train loop only, dict, eval result
trainer = Trainer(fast_dev_run=True)
trainer.fit(train_step_only_model)

# make sure only training step was called
assert train_step_only_model.training_step_called
assert not train_step_only_model.validation_step_called
assert not train_step_only_model.test_step_called

# checkpoint should have been called once with fast dev run
assert len(trainer.dev_debugger.checkpoint_callback_history) == 1


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_mc_called(tmpdir):
seed_everything(1234)
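
The deleted test asserted that exactly one checkpoint was written under fast_dev_run; after this PR the opposite holds. A hedged sketch of a test that would cover the new behaviour (not part of the diff; the test name is invented, and the imports match this file's existing ones):

import os
from unittest import mock

from pytorch_lightning import seed_everything, Trainer
from tests.base import BoringModel


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_mc_not_called_with_fast_dev_run(tmpdir):
    """With fast_dev_run enabled, no checkpoints should be saved at all."""
    seed_everything(1234)

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)

    # the removed test expected exactly one entry here; checkpointing is now skipped entirely
    assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
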