Support val_check_interval values higher than number of training batc…
eladsegal committed Jun 10, 2022
1 parent 03873e5 commit ee5a30e
Showing 3 changed files with 28 additions and 10 deletions.
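
For context: this change lets an integer `val_check_interval` exceed the number of batches in a single training epoch, provided `check_val_every_n_epoch=None` is passed so that validation is scheduled by the global count of training batches. Below is a minimal, hedged sketch of the newly supported configuration (assuming the post-commit Trainer API):

from pytorch_lightning import Trainer

# With `check_val_every_n_epoch=None`, validation is scheduled purely by the number of
# training batches seen so far, so the interval may span epoch boundaries.
trainer = Trainer(
    max_epochs=10,
    val_check_interval=1000,       # validate every 1000 training batches
    check_val_every_n_epoch=None,  # decouple validation from epoch boundaries
)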
10 changes: 8 additions & 2 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -481,7 +481,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if not self.trainer.enable_validation:
             return False

-        is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+        is_val_check_epoch = self.trainer.check_val_every_n_epoch is None or ((self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0)
         if not is_val_check_epoch:
             return False

@@ -498,7 +498,13 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
-            is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
+            # if `check_val_every_n_epoch` is `None`, run a validation loop every n training batches
+            # else condition it based on the batch_idx of the current epoch
+            current_iteration = (
+                self._batches_that_stepped if self.trainer.check_val_every_n_epoch is None else batch_idx
+            )
+            is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0

         return is_val_check_batch

     def _save_loggers_on_train_batch_end(self) -> None:
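The new branch above can be summarized outside the loop class as a small sketch (an illustrative re-implementation, not the actual loop code; `should_check_val` is a made-up helper name):

from typing import Optional


def should_check_val(
    batch_idx: int,
    batches_that_stepped: int,
    val_check_batch: int,
    check_val_every_n_epoch: Optional[int],
) -> bool:
    # Simplified mirror of the new logic in `_should_check_val_fx`: with
    # `check_val_every_n_epoch=None` the counter runs across epochs, otherwise
    # it restarts at every epoch.
    current_iteration = batches_that_stepped if check_val_every_n_epoch is None else batch_idx
    return (current_iteration + 1) % val_check_batch == 0


# Example: 3 training batches per epoch, val_check_interval=5.
# Per-epoch counting never reaches 5, but global counting triggers validation on the
# 5th batch overall (2nd batch of the 2nd epoch: batch_idx=1, batches_that_stepped=4).
assert should_check_val(batch_idx=1, batches_that_stepped=4, val_check_batch=5, check_val_every_n_epoch=None)
assert not should_check_val(batch_idx=1, batches_that_stepped=4, val_check_batch=5, check_val_every_n_epoch=1)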
13 changes: 10 additions & 3 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -71,8 +71,9 @@ def _should_reload_val_dl(self) -> bool:

     def on_trainer_init(
         self,
-        check_val_every_n_epoch: int,
+        val_check_interval: Union[int, float],
         reload_dataloaders_every_n_epochs: int,
+        check_val_every_n_epoch: Optional[int],
         prepare_data_per_node: Optional[bool] = None,
     ) -> None:
         self.trainer.datamodule = None
@@ -85,9 +86,15 @@ def on_trainer_init(
         )
         self.trainer.prepare_data_per_node = prepare_data_per_node

-        if not isinstance(check_val_every_n_epoch, int):
+        if check_val_every_n_epoch is not None and not isinstance(check_val_every_n_epoch, int):
             raise MisconfigurationException(
-                f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
+                f"`check_val_every_n_epoch` should be an integer, found {check_val_every_n_epoch!r}."
             )

+        if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
+            raise MisconfigurationException(
+                "`val_check_interval` should be an integer when `check_val_every_n_epoch=None`,"
+                f" found {val_check_interval!r}."
+            )
+
         self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
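The two argument checks added here amount to the following rules, restated as a standalone sketch (`validate_args` is a made-up helper, and it raises `ValueError` instead of Lightning's `MisconfigurationException`):

from typing import Optional, Union


def validate_args(val_check_interval: Union[int, float], check_val_every_n_epoch: Optional[int]) -> None:
    # `check_val_every_n_epoch` must be an integer or None.
    if check_val_every_n_epoch is not None and not isinstance(check_val_every_n_epoch, int):
        raise ValueError(f"`check_val_every_n_epoch` should be an integer, found {check_val_every_n_epoch!r}.")
    # A float interval is a fraction of one epoch, which has no meaning once validation
    # is decoupled from epochs, so the interval must then be an integer number of batches.
    if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
        raise ValueError(
            f"`val_check_interval` should be an integer when `check_val_every_n_epoch=None`, found {val_check_interval!r}."
        )


validate_args(val_check_interval=100, check_val_every_n_epoch=None)  # ok: batch-based validation
validate_args(val_check_interval=0.25, check_val_every_n_epoch=1)    # ok: fraction of every epoch
# validate_args(val_check_interval=0.25, check_val_every_n_epoch=None)  # would raise ValueError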
15 changes: 10 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -148,7 +148,7 @@ def __init__(
         enable_progress_bar: bool = True,
         overfit_batches: Union[int, float] = 0.0,
         track_grad_norm: Union[int, float, str] = -1,
-        check_val_every_n_epoch: int = 1,
+        check_val_every_n_epoch: Optional[int] = 1,
         fast_dev_run: Union[int, bool] = False,
         accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
         max_epochs: Optional[int] = None,
@@ -248,7 +248,9 @@ def __init__(
             :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
             Default: ``True``.
-        check_val_every_n_epoch: Check val every n train epochs.
+        check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
+            validation will be done solely based on the number of training batches, requiring ``val_check_interval``
+            to be an integer value.
             Default: ``1``.
@@ -427,7 +429,8 @@ def __init__(
         val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
             after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
-            batches.
+            batches. An ``int`` value can only be higher than the number of training batches when
+            ``check_val_every_n_epoch=None``.
             Default: ``1.0``.
         enable_model_summary: Whether to enable model summarization by default.
@@ -561,8 +564,9 @@ def __init__(
         # init data flags
         self.check_val_every_n_epoch: int
         self._data_connector.on_trainer_init(
-            check_val_every_n_epoch,
+            val_check_interval,
             reload_dataloaders_every_n_epochs,
+            check_val_every_n_epoch,
             prepare_data_per_node,
         )

@@ -1862,11 +1866,12 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:

         if isinstance(self.val_check_interval, int):
             self.val_check_batch = self.val_check_interval
-            if self.val_check_batch > self.num_training_batches:
+            if self.val_check_batch > self.num_training_batches and self.check_val_every_n_epoch is not None:
                 raise ValueError(
                     f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                     f"to the number of the training batches ({self.num_training_batches}). "
-                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
+                    "If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
                 )
         else:
             if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
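Taken together with the loop change, the relaxed guard above means an oversized integer `val_check_interval` is only rejected while validation is still tied to epochs. A hedged usage sketch follows (the figure of roughly 320 batches per epoch is an assumed example, not from the commit):

from pytorch_lightning import Trainer

# Assume roughly 320 training batches per epoch.
# Default behaviour (check_val_every_n_epoch=1): an interval larger than one epoch
# still raises the ValueError above once `reset_train_dataloader` runs.
# trainer = Trainer(val_check_interval=1000)  # would fail at the start of fit()

# With check_val_every_n_epoch=None, the same interval is accepted and validation
# runs every 1000 training batches, i.e. roughly once every three epochs.
trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None, max_epochs=10)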
