Support val_check_interval values higher than number of training batches #11993

Merged: 18 commits, Apr 21, 2022
3 changes: 1 addition & 2 deletions CHANGELOG.md
@@ -9,8 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-

- Added support for setting `val_check_interval` to a value higher than the number of training batches when `check_val_every_n_epoch=None` ([#11993](https://github.com/PyTorchLightning/pytorch-lightning/pull/11993))

- Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))

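For context, a minimal usage sketch of the new flag combination (not part of the diff; the specific values are illustrative and assume a `pytorch_lightning` build that includes this change):

```python
from pytorch_lightning import Trainer

# With `check_val_every_n_epoch=None`, validation is scheduled purely by the
# number of training batches seen, so `val_check_interval` may exceed the
# length of a single epoch.
trainer = Trainer(
    val_check_interval=500,        # run validation every 500 training batches
    check_val_every_n_epoch=None,  # disable the per-epoch validation schedule
    max_steps=2000,
)
```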
14 changes: 10 additions & 4 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -501,9 +501,9 @@ def _get_monitor_value(self, key: str) -> Any:
return self.trainer.callback_metrics.get(key)

def _should_check_val_epoch(self):
return (
self.trainer.enable_validation
and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
return self.trainer.enable_validation and (
self.trainer.check_val_every_n_epoch is None
or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
)

def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
@@ -524,7 +524,13 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float("inf"):
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
# if `check_val_every_n_epoch` is `None`, run a validation loop every n training batches
# otherwise, condition it on the batch_idx of the current epoch
current_iteration = (
self._batches_that_stepped if self.trainer.check_val_every_n_epoch is None else batch_idx
)
is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0

return is_val_check_batch

def _save_loggers_on_train_batch_end(self) -> None:
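As a standalone illustration of the scheduling logic above, here is a sketch that mirrors the condition in the diff; the function and argument names are invented for this example and are not part of the loop's API:

```python
def should_run_validation(batches_that_stepped, batch_idx, val_check_batch, check_val_every_n_epoch):
    """Mirror of the finite-dataset condition in `_should_check_val_fx`."""
    # When `check_val_every_n_epoch` is None, count training batches globally
    # across epochs; otherwise count within the current epoch.
    current_iteration = batches_that_stepped if check_val_every_n_epoch is None else batch_idx
    return (current_iteration + 1) % val_check_batch == 0

# With val_check_interval=3 and check_val_every_n_epoch=None, validation runs
# after global batches 3, 6, 9, ... regardless of where epoch boundaries fall.
assert should_run_validation(batches_that_stepped=5, batch_idx=1,
                             val_check_batch=3, check_val_every_n_epoch=None)
```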
13 changes: 10 additions & 3 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -71,14 +71,21 @@ def _should_reload_val_dl(self) -> bool:

def on_trainer_init(
self,
check_val_every_n_epoch: int,
val_check_interval: Union[int, float],
reload_dataloaders_every_n_epochs: int,
check_val_every_n_epoch: Optional[int],
) -> None:
self.trainer.datamodule = None

if not isinstance(check_val_every_n_epoch, int):
if check_val_every_n_epoch is not None and not isinstance(check_val_every_n_epoch, int):
raise MisconfigurationException(
f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
f"`check_val_every_n_epoch` should be an integer, found {check_val_every_n_epoch!r}."
)

if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
raise MisconfigurationException(
"`val_check_interval` should be an integer when `check_val_every_n_epoch=None`,"
f" found {val_check_interval!r}."
)

self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
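The new flag validation can be exercised directly; a short sketch (again assuming an installed version that includes this change):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# A fractional `val_check_interval` is defined relative to the epoch length,
# which has no meaning when `check_val_every_n_epoch=None`, so the combination
# is rejected at Trainer construction time.
try:
    Trainer(val_check_interval=0.5, check_val_every_n_epoch=None)
except MisconfigurationException as err:
    print(err)
```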
16 changes: 10 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -145,7 +145,7 @@ def __init__(
enable_progress_bar: bool = True,
overfit_batches: Union[int, float] = 0.0,
track_grad_norm: Union[int, float, str] = -1,
check_val_every_n_epoch: int = 1,
check_val_every_n_epoch: Optional[int] = 1,
fast_dev_run: Union[int, bool] = False,
accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
max_epochs: Optional[int] = None,
@@ -242,10 +242,11 @@ def __init__(
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
Default: ``True``.

check_val_every_n_epoch: Check val every n train epochs.
check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
validation will be done solely based on the number of training batches, requiring ``val_check_interval``
to be an integer value.
Default: ``1``.


default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
Default: ``os.getcwd()``.
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
@@ -403,7 +404,8 @@ def __init__(

val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
batches.
batches. An ``int`` value can only be higher than the number of training batches when
``check_val_every_n_epoch=None``.
Default: ``1.0``.

enable_model_summary: Whether to enable model summarization by default.
@@ -524,8 +526,9 @@ def __init__(
# init data flags
self.check_val_every_n_epoch: int
self._data_connector.on_trainer_init(
check_val_every_n_epoch,
val_check_interval,
reload_dataloaders_every_n_epochs,
check_val_every_n_epoch,
)

# gradient clipping
@@ -1829,11 +1832,12 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -

if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
if self.val_check_batch > self.num_training_batches:
if self.val_check_batch > self.num_training_batches and self.check_val_every_n_epoch is not None:
raise ValueError(
f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
f"to the number of the training batches ({self.num_training_batches}). "
"If you want to disable validation set `limit_val_batches` to 0.0 instead."
"If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
)
else:
if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
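A self-contained sketch of the relaxed check in `reset_train_dataloader`; the helper name and arguments below are invented purely to isolate the condition from the diff:

```python
def check_val_check_batch(val_check_batch: int, num_training_batches: int, check_val_every_n_epoch) -> None:
    # The interval may exceed the number of training batches only when the
    # per-epoch schedule is disabled (`check_val_every_n_epoch=None`).
    if val_check_batch > num_training_batches and check_val_every_n_epoch is not None:
        raise ValueError(
            f"`val_check_interval` ({val_check_batch}) must be less than or equal "
            f"to the number of the training batches ({num_training_batches})."
        )

check_val_check_batch(100, 10, check_val_every_n_epoch=None)  # allowed after this PR
# check_val_check_batch(100, 10, check_val_every_n_epoch=1)   # would raise ValueError
```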
39 changes: 37 additions & 2 deletions tests/trainer/flags/test_check_val_every_n_epoch.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch.utils.data import DataLoader

from pytorch_lightning.trainer import Trainer
from tests.helpers import BoringModel
from pytorch_lightning.trainer.trainer import Trainer
from tests.helpers import BoringModel, RandomDataset


@pytest.mark.parametrize(
@@ -46,3 +47,37 @@ def on_validation_epoch_start(self) -> None:

assert model.val_epoch_calls == expected_val_loop_calls
assert model.val_batches == expected_val_batches


def test_check_val_every_n_epoch_with_max_steps(tmpdir):
data_samples_train = 2
check_val_every_n_epoch = 3
max_epochs = 4

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.validation_called_at_step = set()

def validation_step(self, *args):
self.validation_called_at_step.add(self.global_step)
return super().validation_step(*args)

def train_dataloader(self):
return DataLoader(RandomDataset(32, data_samples_train))

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=data_samples_train * max_epochs,
check_val_every_n_epoch=check_val_every_n_epoch,
num_sanity_val_steps=0,
)

trainer.fit(model)

# with a data length of 2, validation every 3 epochs, and max_steps=8, we should
# validate once (at global step 6)
assert trainer.current_epoch == max_epochs
assert trainer.global_step == max_epochs * data_samples_train
assert list(model.validation_called_at_step) == [data_samples_train * check_val_every_n_epoch]
73 changes: 71 additions & 2 deletions tests/trainer/flags/test_val_check_interval.py
@@ -14,9 +14,12 @@
import logging

import pytest
from torch.utils.data import DataLoader

from pytorch_lightning.trainer import Trainer
from tests.helpers import BoringModel
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.boring_model import RandomIterableDataset


@pytest.mark.parametrize("max_epochs", [1, 2, 3])
@@ -57,3 +60,69 @@ def test_val_check_interval_info_message(caplog, value):
with caplog.at_level(logging.INFO):
Trainer()
assert message not in caplog.text


@pytest.mark.parametrize("use_infinite_dataset", [True, False])
def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset):
data_samples_train = 4
max_epochs = 3
max_steps = data_samples_train * max_epochs

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.validation_called_at_step = set()

def validation_step(self, *args):
self.validation_called_at_step.add(self.global_step)
return super().validation_step(*args)

def train_dataloader(self):
train_ds = (
RandomIterableDataset(32, count=max_steps + 100)
if use_infinite_dataset
else RandomDataset(32, length=data_samples_train)
)
return DataLoader(train_ds)

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_val_batches=1,
max_steps=max_steps,
val_check_interval=3,
check_val_every_n_epoch=None,
num_sanity_val_steps=0,
)

trainer.fit(model)

assert trainer.current_epoch == (1 if use_infinite_dataset else max_epochs)
assert trainer.global_step == max_steps

# with a data length of 4 (or infinite), a val_check_interval of 3, and max_steps=12,
# we should have validated four times, at global steps 3, 6, 9 and 12
assert sorted(model.validation_called_at_step) == [3, 6, 9, 12]


def test_validation_check_interval_exceed_data_length_wrong():
trainer = Trainer(
limit_train_batches=10,
val_check_interval=100,
)

model = BoringModel()
with pytest.raises(ValueError, match="must be less than or equal to the number of the training batches"):
trainer.fit(model)


def test_val_check_interval_float_with_none_check_val_every_n_epoch():
"""Test that an exception is raised with `val_check_interval` is set to float with
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
`check_val_every_n_epoch=None`"""
with pytest.raises(
MisconfigurationException, match="`val_check_interval` should be an integer when `check_val_every_n_epoch=None`"
):
Trainer(
val_check_interval=0.5,
check_val_every_n_epoch=None,
)