
Support val_check_interval values higher than number of training batches #11993

Merged: 18 commits, Apr 21, 2022
Changes from 10 commits
3 changes: 1 addition & 2 deletions CHANGELOG.md
@@ -9,8 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-

- Support setting `val_check_interval` to a value higher than the amount of training batches when `check_val_every_n_epoch=None` ([#8135](https://github.com/PyTorchLightning/pytorch-lightning/issues/8135))

-

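For context, a minimal sketch of how the new option is intended to be used. The model class and the concrete step counts below are illustrative only, not part of this PR:

```python
import pytorch_lightning as pl

# `MyLightningModule` is a placeholder for any standard LightningModule
# that defines training and validation dataloaders.
model = MyLightningModule()

trainer = pl.Trainer(
    max_steps=10_000,
    # interpreted as a number of global steps, so it may exceed the
    # number of batches in one training epoch
    val_check_interval=1_000,
    # disables the per-epoch validation schedule
    check_val_every_n_epoch=None,
)
trainer.fit(model)
```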
25 changes: 20 additions & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -501,10 +501,18 @@ def _get_monitor_value(self, key: str) -> Any:
return self.trainer.callback_metrics.get(key)

def _should_check_val_epoch(self):
return (
self.trainer.enable_validation
and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
)
if not self.trainer.enable_validation:
return False

# first we check if `check_val_every_n_epoch` is `None`, which means
# that we run a validation loop after n global steps (taken from the
# Trainer argument `val_check_interval`)
if self.trainer.check_val_every_n_epoch is None:
return (self.trainer.global_step + 1) % self.trainer.val_check_batch == 0

# If it's not `None`, we respect it and run a validation loop after every n epochs
else:
return (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0

def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
"""Decide if we should run validation."""
@@ -524,7 +532,14 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float("inf"):
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
# if we're checking based on global step, we can start validation
# at any point in the training epoch
if self.trainer.check_val_every_n_epoch is None:
is_val_check_batch = True
else:
# otherwise, check against the batch index within the current epoch
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0

return is_val_check_batch

def _save_loggers_on_train_batch_end(self) -> None:
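To make the branching above easier to follow, here is a simplified, standalone mirror of `_should_check_val_epoch`. It is an illustration only and takes plain arguments instead of reading them from `self.trainer`:

```python
from typing import Optional


def should_check_val_epoch(
    enable_validation: bool,
    check_val_every_n_epoch: Optional[int],
    current_epoch: int,
    global_step: int,
    val_check_batch: int,
) -> bool:
    """Simplified mirror of `_should_check_val_epoch`, for illustration only."""
    if not enable_validation:
        return False
    if check_val_every_n_epoch is None:
        # step-based schedule: validate every `val_check_batch` global steps
        return (global_step + 1) % val_check_batch == 0
    # epoch-based schedule: validate after every `check_val_every_n_epoch` epochs
    return (current_epoch + 1) % check_val_every_n_epoch == 0
```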
11 changes: 9 additions & 2 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -71,16 +71,23 @@ def _should_reload_val_dl(self) -> bool:

def on_trainer_init(
self,
check_val_every_n_epoch: int,
val_check_interval: Union[int, float],
reload_dataloaders_every_n_epochs: int,
check_val_every_n_epoch: Optional[int] = None,
) -> None:
self.trainer.datamodule = None

if not isinstance(check_val_every_n_epoch, int):
if check_val_every_n_epoch is not None and not isinstance(check_val_every_n_epoch, int):
raise MisconfigurationException(
f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
)

if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
raise MisconfigurationException(
f"val_check_interval should be an integer when check_val_every_n_epoch={check_val_every_n_epoch}. "
f"Found val_check_interval={val_check_interval}"
)

self.trainer.check_val_every_n_epoch = check_val_every_n_epoch

if not isinstance(reload_dataloaders_every_n_epochs, int) or (reload_dataloaders_every_n_epochs < 0):
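The new argument validation can be exercised directly. A small sketch of the behaviour added in this PR (the concrete values are assumptions for the example):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# A fractional `val_check_interval` has no meaning when validation is scheduled
# purely by global step, so this combination is rejected when the Trainer is built.
try:
    Trainer(val_check_interval=0.5, check_val_every_n_epoch=None)
except MisconfigurationException as err:
    print(err)
```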
18 changes: 11 additions & 7 deletions pytorch_lightning/trainer/trainer.py
@@ -149,7 +149,7 @@ def __init__(
enable_progress_bar: bool = True,
overfit_batches: Union[int, float] = 0.0,
track_grad_norm: Union[int, float, str] = -1,
check_val_every_n_epoch: int = 1,
check_val_every_n_epoch: Optional[int] = 1,
fast_dev_run: Union[int, bool] = False,
accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
max_epochs: Optional[int] = None,
@@ -247,10 +247,10 @@ def __init__(
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
Default: ``True``.

check_val_every_n_epoch: Check val every n train epochs.
check_val_every_n_epoch: Perform a validation loop after every `n` training epochs. If ``None``, validation
will be done solely based on the number of steps, requiring ``val_check_interval`` to be an integer value.
Default: ``1``.


default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
Default: ``os.getcwd()``.
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
@@ -408,7 +408,8 @@ def __init__(

val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
batches.
batches. An ``int`` value can only exceed the number of batches in the training set when
``check_val_every_n_epoch=None``, since otherwise the validation set would never be checked.
Default: ``1.0``.

enable_model_summary: Whether to enable model summarization by default.
@@ -531,8 +532,9 @@ def __init__(
# init data flags
self.check_val_every_n_epoch: int
self._data_connector.on_trainer_init(
check_val_every_n_epoch,
val_check_interval,
reload_dataloaders_every_n_epochs,
check_val_every_n_epoch,
)

if terminate_on_nan is not None:
Expand Down Expand Up @@ -1883,11 +1885,13 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -

if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
if self.val_check_batch > self.num_training_batches:
if self.val_check_batch > self.num_training_batches and self.check_val_every_n_epoch is not None:
raise ValueError(
f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
f"to the number of the training batches ({self.num_training_batches}). "
"If you want to disable validation set `limit_val_batches` to 0.0 instead."
"If you want to disable validation set `limit_val_batches` to 0.0 instead. "
"If you want to validate based on the step count instead of the epoch count, "
"set `check_val_every_n_epoch=None`."
)
else:
if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
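As a quick illustration of the updated check in `reset_train_dataloader` (the batch counts are assumptions for the example, not values from this PR):

```python
from pytorch_lightning import Trainer

# Suppose the training dataloader yields 100 batches per epoch.

# Epoch-based schedule (default): the interval must fit within one epoch.
ok_trainer = Trainer(val_check_interval=50)

# Step-based schedule: the interval may exceed one epoch's worth of batches.
ok_step_trainer = Trainer(val_check_interval=250, check_val_every_n_epoch=None)

# With the default epoch-based schedule, an interval of 250 would raise the
# ValueError above during `fit()`, once the dataloader length is known.
```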
114 changes: 113 additions & 1 deletion tests/loops/test_training_loop.py
@@ -13,9 +13,11 @@
# limitations under the License.
import pytest
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import seed_everything, Trainer
from tests.helpers import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset


def test_outputs_format(tmpdir):
@@ -151,3 +153,113 @@ def training_step_end(self, outputs):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)

trainer.fit(model)


def test_validation_check_interval_exceed_data_length_correct(tmpdir):
batch_size = 32
data_samples_train = 10
data_samples_val = 1

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.validation_called_at_step = set()

def training_step(self, batch, batch_idx):
return super().training_step(batch, batch_idx)

def validation_step(self, *args):
self.validation_called_at_step.add(int(self.trainer.global_step))
return super().validation_step(*args)

def train_dataloader(self):
return DataLoader(RandomDataset(batch_size, data_samples_train))

def val_dataloader(self):
return DataLoader(RandomDataset(batch_size, data_samples_val))

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=data_samples_train * 3,
val_check_interval=15,
check_val_every_n_epoch=None,
num_sanity_val_steps=0,
)

trainer.fit(model)

# with a data length of 10, a val_check_interval of 15, and max_steps=30, we
# should have validated twice
assert trainer.current_epoch == 3
assert trainer.global_step == 30
assert sorted(list(model.validation_called_at_step)) == [14, 29]


def test_validation_check_interval_exceed_data_length_wrong(tmpdir):
model = BoringModel()

with pytest.raises(ValueError):
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=200,
val_check_interval=100,
check_val_every_n_epoch=1,
num_sanity_val_steps=0,
)
trainer.fit(model)


def test_validation_check_interval_float_wrong(tmpdir):
model = BoringModel()

with pytest.raises(MisconfigurationException):
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=200,
val_check_interval=0.5,
check_val_every_n_epoch=None,
num_sanity_val_steps=0,
)
trainer.fit(model)


def test_validation_loop_every_5_epochs(tmpdir):
batch_size = 32
data_samples_train = 10
data_samples_val = 1

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.validation_called_at_step = set()

def training_step(self, batch, batch_idx):
return super().training_step(batch, batch_idx)

def validation_step(self, *args):
self.validation_called_at_step.add(int(self.trainer.global_step))
return super().validation_step(*args)

def train_dataloader(self):
return DataLoader(RandomDataset(batch_size, data_samples_train))

def val_dataloader(self):
return DataLoader(RandomDataset(batch_size, data_samples_val))

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=data_samples_train * 9,
check_val_every_n_epoch=5,
num_sanity_val_steps=0,
)

trainer.fit(model)

# with a data length of 10, validation every 5 epochs, and max_steps=90, we should
# validate once
assert trainer.current_epoch == 9
assert trainer.global_step == 90
assert list(model.validation_called_at_step) == [50]