diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9398530102de7..47c3af94ff7f0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,8 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
-
+- Added support for setting `val_check_interval` to a value higher than the number of training batches when `check_val_every_n_epoch=None` ([#11993](https://github.com/PyTorchLightning/pytorch-lightning/pull/11993))
 
 - Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))
 
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index b6887a4cf546c..ec02a099f6cbb 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -501,9 +501,9 @@ def _get_monitor_value(self, key: str) -> Any:
         return self.trainer.callback_metrics.get(key)
 
     def _should_check_val_epoch(self):
-        return (
-            self.trainer.enable_validation
-            and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+        return self.trainer.enable_validation and (
+            self.trainer.check_val_every_n_epoch is None
+            or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         )
 
     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
@@ -524,7 +524,13 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
+            # if `check_val_every_n_epoch` is `None`, run a validation loop every n training batches
+            # else condition it based on the batch_idx of the current epoch
-            is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
+            current_iteration = (
+                self._batches_that_stepped if self.trainer.check_val_every_n_epoch is None else batch_idx
+            )
+            is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0
+
         return is_val_check_batch
 
     def _save_loggers_on_train_batch_end(self) -> None:
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index f1943073da819..978271f2a5d11 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -71,14 +71,21 @@ def _should_reload_val_dl(self) -> bool:
 
     def on_trainer_init(
         self,
-        check_val_every_n_epoch: int,
+        val_check_interval: Union[int, float],
         reload_dataloaders_every_n_epochs: int,
+        check_val_every_n_epoch: Optional[int],
     ) -> None:
         self.trainer.datamodule = None
-        if not isinstance(check_val_every_n_epoch, int):
+
+        if check_val_every_n_epoch is not None and not isinstance(check_val_every_n_epoch, int):
             raise MisconfigurationException(
-                f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
+                f"`check_val_every_n_epoch` should be an integer, found {check_val_every_n_epoch!r}."
+            )
+
+        if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
+            raise MisconfigurationException(
+                "`val_check_interval` should be an integer when `check_val_every_n_epoch=None`,"
+                f" found {val_check_interval!r}."
             )
 
         self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 3d25075e2985e..189281627e5b9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -145,7 +145,7 @@ def __init__(
         enable_progress_bar: bool = True,
         overfit_batches: Union[int, float] = 0.0,
         track_grad_norm: Union[int, float, str] = -1,
-        check_val_every_n_epoch: int = 1,
+        check_val_every_n_epoch: Optional[int] = 1,
         fast_dev_run: Union[int, bool] = False,
         accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
         max_epochs: Optional[int] = None,
@@ -242,10 +242,11 @@ def __init__(
                 :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
                 Default: ``True``.
 
-            check_val_every_n_epoch: Check val every n train epochs.
+            check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
+                validation will be done solely based on the number of training batches, requiring ``val_check_interval``
+                to be an integer value.
                 Default: ``1``.
 
-
             default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
                 Default: ``os.getcwd()``.
                 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
@@ -403,7 +404,8 @@ def __init__(
 
             val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
                 after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
-                batches.
+                batches. An ``int`` value can only be higher than the number of training batches when
+                ``check_val_every_n_epoch=None``.
                 Default: ``1.0``.
 
             enable_model_summary: Whether to enable model summarization by default.
@@ -524,8 +526,9 @@ def __init__(
         # init data flags
         self.check_val_every_n_epoch: int
         self._data_connector.on_trainer_init(
-            check_val_every_n_epoch,
+            val_check_interval,
             reload_dataloaders_every_n_epochs,
+            check_val_every_n_epoch,
         )
 
         # gradient clipping
@@ -1829,11 +1832,12 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
 
         if isinstance(self.val_check_interval, int):
             self.val_check_batch = self.val_check_interval
-            if self.val_check_batch > self.num_training_batches:
+            if self.val_check_batch > self.num_training_batches and self.check_val_every_n_epoch is not None:
                 raise ValueError(
                     f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                     f"to the number of the training batches ({self.num_training_batches}). "
                     "If you want to disable validation set `limit_val_batches` to 0.0 instead."
+                    " If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
                 )
         else:
             if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py
index 97c6ddf7803ab..ca2537b829cd7 100644
--- a/tests/trainer/flags/test_check_val_every_n_epoch.py
+++ b/tests/trainer/flags/test_check_val_every_n_epoch.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
+from torch.utils.data import DataLoader
 
-from pytorch_lightning.trainer import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.trainer.trainer import Trainer
+from tests.helpers import BoringModel, RandomDataset
 
 
 @pytest.mark.parametrize(
@@ -46,3 +47,35 @@ def on_validation_epoch_start(self) -> None:
 
     assert model.val_epoch_calls == expected_val_loop_calls
     assert model.val_batches == expected_val_batches
+
+
+def test_check_val_every_n_epoch_with_max_steps(tmpdir):
+    data_samples_train = 2
+    check_val_every_n_epoch = 3
+    max_epochs = 4
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.validation_called_at_step = set()
+
+        def validation_step(self, *args):
+            self.validation_called_at_step.add(self.global_step)
+            return super().validation_step(*args)
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, data_samples_train))
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=data_samples_train * max_epochs,
+        check_val_every_n_epoch=check_val_every_n_epoch,
+        num_sanity_val_steps=0,
+    )
+
+    trainer.fit(model)
+
+    assert trainer.current_epoch == max_epochs
+    assert trainer.global_step == max_epochs * data_samples_train
+    assert list(model.validation_called_at_step) == [data_samples_train * check_val_every_n_epoch]
diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py
index 685e104805daa..b575faa81203c 100644
--- a/tests/trainer/flags/test_val_check_interval.py
+++ b/tests/trainer/flags/test_val_check_interval.py
@@ -14,9 +14,12 @@
 import logging
 
 import pytest
+from torch.utils.data import DataLoader
 
-from pytorch_lightning.trainer import Trainer
-from tests.helpers import BoringModel
+from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.boring_model import RandomIterableDataset
 
 
 @pytest.mark.parametrize("max_epochs", [1, 2, 3])
@@ -57,3 +60,66 @@ def test_val_check_interval_info_message(caplog, value):
     with caplog.at_level(logging.INFO):
         Trainer()
         assert message not in caplog.text
+
+
+@pytest.mark.parametrize("use_infinite_dataset", [True, False])
+def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset):
+    data_samples_train = 4
+    max_epochs = 3
+    max_steps = data_samples_train * max_epochs
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.validation_called_at_step = set()
+
+        def validation_step(self, *args):
+            self.validation_called_at_step.add(self.global_step)
+            return super().validation_step(*args)
+
+        def train_dataloader(self):
+            train_ds = (
+                RandomIterableDataset(32, count=max_steps + 100)
+                if use_infinite_dataset
+                else RandomDataset(32, length=data_samples_train)
+            )
+            return DataLoader(train_ds)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_val_batches=1,
+        max_steps=max_steps,
+        val_check_interval=3,
+        check_val_every_n_epoch=None,
+        num_sanity_val_steps=0,
+    )
+
+    trainer.fit(model)
+
+    assert trainer.current_epoch == (1 if use_infinite_dataset else max_epochs)
+    assert trainer.global_step == max_steps
+    assert sorted(model.validation_called_at_step) == [3, 6, 9, 12]
+
+
+def test_validation_check_interval_exceed_data_length_wrong():
+    trainer = Trainer(
+        limit_train_batches=10,
+        val_check_interval=100,
+    )
+
+    model = BoringModel()
+    with pytest.raises(ValueError, match="must be less than or equal to the number of the training batches"):
+        trainer.fit(model)
+
+
+def test_val_check_interval_float_with_none_check_val_every_n_epoch():
+    """Test that an exception is raised when `val_check_interval` is set to float with
+    `check_val_every_n_epoch=None`."""
+    with pytest.raises(
+        MisconfigurationException, match="`val_check_interval` should be an integer when `check_val_every_n_epoch=None`"
+    ):
+        Trainer(
+            val_check_interval=0.5,
+            check_val_every_n_epoch=None,
+        )
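
For reviewers who want to try the new flag combination locally, here is a minimal sketch (not part of the diff). It assumes a `pytorch_lightning` build that includes this change; `ToyDataset` and `ToyModel` are made-up names for illustration. With `check_val_every_n_epoch=None` and an integer `val_check_interval`, validation is scheduled on the global count of training batches, so the interval may exceed the number of batches in a single epoch:

```python
# Illustrative usage sketch only. `ToyDataset` / `ToyModel` are hypothetical helpers,
# not part of the Lightning API or of this PR.
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class ToyDataset(Dataset):
    """Tiny map-style dataset: 4 samples of size 32, so one epoch has 4 batches."""

    def __init__(self, size: int = 32, length: int = 4):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", self.layer(batch).sum())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = pl.Trainer(
        max_steps=20,
        # integer interval larger than the 4 batches per epoch: only allowed together
        # with `check_val_every_n_epoch=None`, which counts batches across epochs
        val_check_interval=5,
        check_val_every_n_epoch=None,
        limit_val_batches=2,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
    )
    trainer.fit(
        ToyModel(),
        train_dataloaders=DataLoader(ToyDataset(), batch_size=1),
        val_dataloaders=DataLoader(ToyDataset(), batch_size=1),
    )
    # validation runs after training batches 5, 10, 15 and 20 instead of raising the
    # "must be less than or equal to the number of the training batches" ValueError
```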