Support val_check_interval values higher than number of training batches (#11993)

Co-authored-by: rohitgr7 <[email protected]>
nikvaessen and rohitgr7 authored Apr 21, 2022
1 parent f300b60 commit a758d90
Showing 6 changed files with 134 additions and 19 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
@@ -9,8 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-

- Added support for setting `val_check_interval` to a value higher than the amount of training batches when `check_val_every_n_epoch=None` ([#11993](https://github.com/PyTorchLightning/pytorch-lightning/pull/11993))

- Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))

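For context, a minimal usage sketch of what this changelog entry enables, assuming the repo's BoringModel test helper (any LightningModule with a validation_step works); the interval of 1000 and max_steps value are arbitrary:

from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # repo test helper; any LightningModule works

model = BoringModel()

# Validate every 1000 training batches, counted across epoch boundaries.
# This combination is only accepted when check_val_every_n_epoch is explicitly None.
trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None, max_steps=5000)
trainer.fit(model)
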
14 changes: 10 additions & 4 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -501,9 +501,9 @@ def _get_monitor_value(self, key: str) -> Any:
return self.trainer.callback_metrics.get(key)

def _should_check_val_epoch(self):
return (
self.trainer.enable_validation
and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
return self.trainer.enable_validation and (
self.trainer.check_val_every_n_epoch is None
or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
)

def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
@@ -524,7 +524,13 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float("inf"):
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
# if `check_val_every_n_epoch` is `None`, run a validation loop every n training batches
# else condition it based on the batch_idx of the current epoch
current_iteration = (
self._batches_that_stepped if self.trainer.check_val_every_n_epoch is None else batch_idx
)
is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0

return is_val_check_batch

def _save_loggers_on_train_batch_end(self) -> None:
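The behavioural core of this hunk is the choice of counter. Below is a standalone sketch (illustrative only, not Lightning internals) of the same decision, showing how the index switches from per-epoch to global when check_val_every_n_epoch is None:

from typing import Optional

def should_run_val(
    batch_idx: int,             # zero-based batch index within the current epoch
    batches_that_stepped: int,  # zero-based batch count across all epochs
    val_check_batch: int,       # integer val_check_interval
    check_val_every_n_epoch: Optional[int],
) -> bool:
    # Mirror of the condition in the diff above.
    current_iteration = batches_that_stepped if check_val_every_n_epoch is None else batch_idx
    return (current_iteration + 1) % val_check_batch == 0

# With val_check_interval=3 and 4 batches per epoch, global counting triggers at
# batches 3, 6, 9, 12 (matching the new test in test_val_check_interval.py below),
# whereas an interval larger than the epoch length would never fire under
# per-epoch counting.
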
13 changes: 10 additions & 3 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -71,14 +71,21 @@ def _should_reload_val_dl(self) -> bool:

def on_trainer_init(
self,
check_val_every_n_epoch: int,
val_check_interval: Union[int, float],
reload_dataloaders_every_n_epochs: int,
check_val_every_n_epoch: Optional[int],
) -> None:
self.trainer.datamodule = None

if not isinstance(check_val_every_n_epoch, int):
if check_val_every_n_epoch is not None and not isinstance(check_val_every_n_epoch, int):
raise MisconfigurationException(
f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
f"`check_val_every_n_epoch` should be an integer, found {check_val_every_n_epoch!r}."
)

if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
raise MisconfigurationException(
"`val_check_interval` should be an integer when `check_val_every_n_epoch=None`,"
f" found {val_check_interval!r}."
)

self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
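A sketch of the new guard from the caller's side (the configuration values are arbitrary): a float val_check_interval is a fraction of one epoch, which has no meaning once epoch-based scheduling is disabled, so the Trainer rejects the combination at construction time.

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    # Fractional interval with no epoch boundary to measure it against: rejected.
    Trainer(val_check_interval=0.5, check_val_every_n_epoch=None)
except MisconfigurationException as err:
    print(err)
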
16 changes: 10 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -145,7 +145,7 @@ def __init__(
enable_progress_bar: bool = True,
overfit_batches: Union[int, float] = 0.0,
track_grad_norm: Union[int, float, str] = -1,
check_val_every_n_epoch: int = 1,
check_val_every_n_epoch: Optional[int] = 1,
fast_dev_run: Union[int, bool] = False,
accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
max_epochs: Optional[int] = None,
@@ -242,10 +242,11 @@ def __init__(
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
Default: ``True``.
check_val_every_n_epoch: Check val every n train epochs.
check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
validation will be done solely based on the number of training batches, requiring ``val_check_interval``
to be an integer value.
Default: ``1``.
default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
Default: ``os.getcwd()``.
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
@@ -403,7 +404,8 @@ def __init__(
val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
batches.
batches. An ``int`` value can only be higher than the number of training batches when
``check_val_every_n_epoch=None``.
Default: ``1.0``.
enable_model_summary: Whether to enable model summarization by default.
@@ -524,8 +526,9 @@ def __init__(
# init data flags
self.check_val_every_n_epoch: int
self._data_connector.on_trainer_init(
check_val_every_n_epoch,
val_check_interval,
reload_dataloaders_every_n_epochs,
check_val_every_n_epoch,
)

# gradient clipping
@@ -1829,11 +1832,12 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -

if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
if self.val_check_batch > self.num_training_batches:
if self.val_check_batch > self.num_training_batches and self.check_val_every_n_epoch is not None:
raise ValueError(
f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
f"to the number of the training batches ({self.num_training_batches}). "
"If you want to disable validation set `limit_val_batches` to 0.0 instead."
"If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
)
else:
if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
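To show both sides of the relaxed check in reset_train_dataloader (again assuming the repo's BoringModel helper; the batch counts are arbitrary): the default epoch-based schedule still raises the ValueError above during fit, while check_val_every_n_epoch=None turns the same interval into a cross-epoch cadence.

from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # repo test helper; any LightningModule works

model = BoringModel()

# Default check_val_every_n_epoch=1: an interval larger than the epoch length
# is rejected when the train dataloader is set up inside fit().
strict = Trainer(limit_train_batches=10, val_check_interval=100, max_epochs=1)
# strict.fit(model)  # ValueError: "... must be less than or equal to the number of the training batches ..."

# With check_val_every_n_epoch=None the same interval is accepted and the
# validation loop runs every 100 training batches, counted across epochs.
relaxed = Trainer(
    limit_train_batches=10,
    val_check_interval=100,
    check_val_every_n_epoch=None,
    max_steps=300,
)
relaxed.fit(model)
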
37 changes: 35 additions & 2 deletions tests/trainer/flags/test_check_val_every_n_epoch.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch.utils.data import DataLoader

from pytorch_lightning.trainer import Trainer
from tests.helpers import BoringModel
from pytorch_lightning.trainer.trainer import Trainer
from tests.helpers import BoringModel, RandomDataset


@pytest.mark.parametrize(
@@ -46,3 +47,35 @@ def on_validation_epoch_start(self) -> None:

assert model.val_epoch_calls == expected_val_loop_calls
assert model.val_batches == expected_val_batches


def test_check_val_every_n_epoch_with_max_steps(tmpdir):
data_samples_train = 2
check_val_every_n_epoch = 3
max_epochs = 4

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.validation_called_at_step = set()

def validation_step(self, *args):
self.validation_called_at_step.add(self.global_step)
return super().validation_step(*args)

def train_dataloader(self):
return DataLoader(RandomDataset(32, data_samples_train))

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=data_samples_train * max_epochs,
check_val_every_n_epoch=check_val_every_n_epoch,
num_sanity_val_steps=0,
)

trainer.fit(model)

assert trainer.current_epoch == max_epochs
assert trainer.global_step == max_epochs * data_samples_train
assert list(model.validation_called_at_step) == [data_samples_train * check_val_every_n_epoch]
70 changes: 68 additions & 2 deletions tests/trainer/flags/test_val_check_interval.py
@@ -14,9 +14,12 @@
import logging

import pytest
from torch.utils.data import DataLoader

from pytorch_lightning.trainer import Trainer
from tests.helpers import BoringModel
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.boring_model import RandomIterableDataset


@pytest.mark.parametrize("max_epochs", [1, 2, 3])
@@ -57,3 +60,66 @@ def test_val_check_interval_info_message(caplog, value):
with caplog.at_level(logging.INFO):
Trainer()
assert message not in caplog.text


@pytest.mark.parametrize("use_infinite_dataset", [True, False])
def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset):
data_samples_train = 4
max_epochs = 3
max_steps = data_samples_train * max_epochs

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.validation_called_at_step = set()

def validation_step(self, *args):
self.validation_called_at_step.add(self.global_step)
return super().validation_step(*args)

def train_dataloader(self):
train_ds = (
RandomIterableDataset(32, count=max_steps + 100)
if use_infinite_dataset
else RandomDataset(32, length=data_samples_train)
)
return DataLoader(train_ds)

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_val_batches=1,
max_steps=max_steps,
val_check_interval=3,
check_val_every_n_epoch=None,
num_sanity_val_steps=0,
)

trainer.fit(model)

assert trainer.current_epoch == (1 if use_infinite_dataset else max_epochs)
assert trainer.global_step == max_steps
assert sorted(list(model.validation_called_at_step)) == [3, 6, 9, 12]


def test_validation_check_interval_exceed_data_length_wrong():
trainer = Trainer(
limit_train_batches=10,
val_check_interval=100,
)

model = BoringModel()
with pytest.raises(ValueError, match="must be less than or equal to the number of the training batches"):
trainer.fit(model)


def test_val_check_interval_float_with_none_check_val_every_n_epoch():
"""Test that an exception is raised when `val_check_interval` is set to float with
`check_val_every_n_epoch=None`"""
with pytest.raises(
MisconfigurationException, match="`val_check_interval` should be an integer when `check_val_every_n_epoch=None`"
):
Trainer(
val_check_interval=0.5,
check_val_every_n_epoch=None,
)
