[bugfix] Fix dataloading for iterable datasets and limit_train_batches #7306

Merged: 22 commits, May 3, 2021

6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -297,6 +297,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed


- Fixed NaN errors in progress bars when training with iterable datasets with no length defined ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306))


- Fixed validation being skipped for iterable datasets with no length defined ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306))


- Fixed attaching train and validation dataloaders when `reload_dataloaders_every_epoch=True` and `num_sanity_val_steps=0` ([#7207](https://github.com/PyTorchLightning/pytorch-lightning/pull/7207))


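For orientation, both new entries concern the same scenario: an `IterableDataset` with no `__len__`, so the trainer cannot know the number of batches up front. The sketch below is illustrative only; the class name and sizes are made up and not taken from this PR.

```python
# Illustrative only: an IterableDataset that deliberately defines no __len__,
# so the number of training batches is unknown (treated as inf by the trainer).
import torch
from torch.utils.data import DataLoader, IterableDataset


class NoLenIterableDataset(IterableDataset):

    def __init__(self, size: int = 32, count: int = 64):
        self.size = size
        self.count = count

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)


# len(loader) raises TypeError here; per the changelog entries above, training on
# such a loader previously could show NaN totals in the progress bar and skip validation.
loader = DataLoader(NoLenIterableDataset(), batch_size=4)
```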
4 changes: 1 addition & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -231,9 +231,7 @@ def on_train_batch_end(
self.save_checkpoint(trainer)

def on_validation_end(self, trainer, pl_module) -> None:
"""
checkpoints can be saved at the end of the val loop
"""
""" Save a checkpoint at the end of the validation stage. """
skip = (
self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
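Aside from the docstring, the skip condition shown above is what gates saving from this hook. A small hedged sketch of just the `_every_n_val_epochs` part of that check, with simplified, made-up names rather than the actual callback:

```python
# Simplified stand-in for the every-n-val-epochs portion of the skip condition above.
def skip_for_every_n_val_epochs(current_epoch: int, every_n_val_epochs: int) -> bool:
    # Skip when the feature is disabled (< 1) or this epoch is not on the interval.
    return every_n_val_epochs < 1 or (current_epoch + 1) % every_n_val_epochs != 0


# With every_n_val_epochs=2, checkpoints are saved after epochs 1, 3, 5 (0-indexed).
assert [e for e in range(6) if not skip_for_every_n_val_epochs(e, 2)] == [1, 3, 5]
# A value < 1 disables saving from this hook entirely.
assert all(skip_for_every_n_val_epochs(e, 0) for e in range(6))
```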
15 changes: 9 additions & 6 deletions pytorch_lightning/callbacks/progress.py
@@ -20,6 +20,7 @@
"""
import importlib
import io
import math
import os
import sys

@@ -397,7 +398,7 @@ def on_train_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
-if total_train_batches != float('inf'):
+if total_train_batches != float('inf') and total_val_batches != float('inf'):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
@@ -407,7 +408,9 @@ def on_train_epoch_start(self, trainer, pl_module):

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
+total_batches = self.total_train_batches + self.total_val_batches
+total_batches = convert_inf(total_batches)
+if self._should_update(self.train_batch_idx, total_batches):
self._update_bar(self.main_progress_bar)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

@@ -422,7 +425,7 @@ def on_validation_start(self, trainer, pl_module):

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-if self._should_update(self.val_batch_idx, self.total_val_batches):
+if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)):
self._update_bar(self.val_progress_bar)
self._update_bar(self.main_progress_bar)

@@ -479,7 +482,7 @@ def print(
s = sep.join(map(str, args))
active_progress_bar.write(s, end=end, file=file, nolock=nolock)

-def _should_update(self, current, total):
+def _should_update(self, current, total) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def _update_bar(self, bar: Optional[tqdm]) -> None:
@@ -496,8 +499,8 @@ def _update_bar(self, bar: Optional[tqdm]) -> None:


def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
""" The tqdm doesn't support inf values. We have to convert it to None. """
if x == float('inf'):
""" The tqdm doesn't support inf/nan values. We have to convert it to None. """
if x is None or math.isinf(x) or math.isnan(x):
ananthsub marked this conversation as resolved.
Show resolved Hide resolved
return None
return x

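To illustrate what the `convert_inf` guard buys, here is a hedged standalone sketch: tqdm cannot render `inf`/`nan` totals, so an unknown total becomes `None`. The `should_update` function is a simplified stand-in for the callback's `_should_update` (it omits the `is_enabled` check) and is not the library code itself.

```python
import math
from typing import Optional, Union


def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    """tqdm cannot render inf/nan totals, so map them to None (unknown total)."""
    if x is None or math.isinf(x) or math.isnan(x):
        return None
    return x


def should_update(current: int, total: Optional[int], refresh_rate: int = 1) -> bool:
    """Simplified refresh check: update periodically, or when the last batch is reached.

    With total=None (unsized iterable dataset), `current == total` is simply False
    instead of producing inf/NaN arithmetic, so the bar still refreshes on the rate.
    """
    return current % refresh_rate == 0 or current == total


assert convert_inf(float('inf')) is None
assert convert_inf(float('nan')) is None
assert convert_inf(10) == 10
assert should_update(5, convert_inf(float('inf')), refresh_rate=5)
```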
43 changes: 29 additions & 14 deletions pytorch_lightning/trainer/training_loop.py
@@ -484,9 +484,9 @@ def run_training_epoch(self):
self.trainer.logger_connector.log_train_step_metrics(batch_output)

# -----------------------------------------
-# VALIDATE IF NEEDED + CHECKPOINT CALLBACK
+# VALIDATE IF NEEDED
# -----------------------------------------
-should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
+should_check_val = self._should_check_val_fx(batch_idx, is_last_batch)
if should_check_val:
self.trainer.validating = True
self.trainer.run_evaluation()
@@ -535,7 +535,7 @@ def run_training_epoch(self):
# log epoch metrics
self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)

-should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval

@@ -825,19 +825,34 @@ def should_accumulate(self):
is_final_batch = self._num_training_batches_reached()
return not (accumulation_done or is_final_batch)

-def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
-    # decide if we should run validation
-    is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
-    is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
-    can_check_val = self.trainer.enable_validation and is_val_check_epoch
-    is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-    epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
-
-    should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
-                        or is_last_batch_for_infinite_dataset
-                        ) if on_epoch else (is_val_check_batch and not epoch_end_val_check)
-
-    return should_check_val and can_check_val
+def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
+    """ Decide if we should run validation. """
+
+    if not self.trainer.enable_validation:
+        return False
+
+    # check if this epoch is eligible to run validation
+    if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
+        return False
+
+    # val_check_batch is inf for iterable datasets with no length defined
+    # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
+    is_val_check_batch = False
+    if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
+        is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
+    elif self.trainer.val_check_batch != float('inf'):
+        is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
+
+    # Note: num_training_batches is also inf for iterable datasets with no length defined
+    epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
+    is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
+
+    if on_epoch:
+        return (
+            is_val_check_batch and epoch_end_val_check
+        ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
+    else:
+        return is_val_check_batch and not epoch_end_val_check

Review thread on the `limit_train_batches` branch above:

Contributor: Love this refactor!

Author: @kaushikb11 thanks! It still feels complicated to me. Part of that comes from limit_train_batches / val_check_interval having different types and possible meanings depending on both the user input and the dataloader specified.

I'm wondering what's a better way to split "when to stop training mid-epoch" from when to run validation, or whether a split is needed at all.

def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
# enable not needing to add opt_idx to training_step
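To make the refactored control flow easier to follow, here is a hedged, standalone re-derivation of the mid-epoch path of `_should_check_val_fx` (the `on_epoch=False` case), with the trainer attributes passed in as plain arguments. The function name and all concrete values in the asserts are illustrative assumptions, not the library's API or the trainer's actual internal state.

```python
# Standalone illustration of the mid-epoch validation check (on_epoch=False).
def should_check_val_mid_epoch(
    batch_idx: int,
    enable_validation: bool,
    current_epoch: int,
    check_val_every_n_epoch: int,
    limit_train_batches,          # int or float, mirroring the Trainer flag
    val_check_batch: float,       # float('inf') for iterable datasets with no length
    num_training_batches: float,  # also float('inf') for such datasets, per the diff comments
) -> bool:
    if not enable_validation:
        return False
    # only epochs that are multiples of check_val_every_n_epoch may validate
    if (current_epoch + 1) % check_val_every_n_epoch != 0:
        return False

    is_val_check_batch = False
    if isinstance(limit_train_batches, int) and val_check_batch == float('inf'):
        # unsized iterable dataset capped by an integer limit_train_batches:
        # validate every limit_train_batches batches instead of never
        is_val_check_batch = (batch_idx + 1) % limit_train_batches == 0
    elif val_check_batch != float('inf'):
        is_val_check_batch = (batch_idx + 1) % val_check_batch == 0

    # epoch-end validation is handled by the on_epoch=True call instead
    epoch_end_val_check = (batch_idx + 1) % num_training_batches == 0
    return is_val_check_batch and not epoch_end_val_check


# Sized loader: 20 batches per epoch, validate every 5 batches.
assert should_check_val_mid_epoch(4, True, 0, 1, 1.0, 5, 20)
assert not should_check_val_mid_epoch(19, True, 0, 1, 1.0, 5, 20)
# Unsized iterable loader with limit_train_batches=10 (inf placeholders as in the diff's comments).
assert should_check_val_mid_epoch(9, True, 0, 1, 10, float('inf'), float('inf'))
```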
27 changes: 26 additions & 1 deletion tests/helpers/boring_model.py
@@ -14,7 +14,7 @@
from typing import Optional

import torch
-from torch.utils.data import DataLoader, Dataset, Subset
+from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset

from pytorch_lightning import LightningDataModule, LightningModule

@@ -60,6 +60,31 @@ def __len__(self):
return self.len


class RandomIterableDataset(IterableDataset):

def __init__(self, size: int, count: int):
self.count = count
self.size = size

def __iter__(self):
for _ in range(self.count):
yield torch.randn(self.size)


class RandomIterableDatasetWithLen(IterableDataset):

def __init__(self, size: int, count: int):
self.count = count
self.size = size

def __iter__(self):
for _ in range(len(self)):
yield torch.randn(self.size)

def __len__(self):
return self.count


class BoringModel(LightningModule):

def __init__(self):
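As a hypothetical usage note (not taken from this PR's tests), the new helpers would typically be wrapped in a plain `DataLoader`. The key property is that `RandomIterableDataset` defines no `__len__`, so the resulting loader has no length either, while `RandomIterableDatasetWithLen` does.

```python
# Hypothetical usage of the new test helpers; import path assumes the repo layout above.
from torch.utils.data import DataLoader

from tests.helpers.boring_model import RandomIterableDataset, RandomIterableDatasetWithLen

unsized = DataLoader(RandomIterableDataset(size=32, count=8), batch_size=4)
sized = DataLoader(RandomIterableDatasetWithLen(size=32, count=8), batch_size=4)

len(sized)        # works: PyTorch derives a length from the dataset's __len__
try:
    len(unsized)  # TypeError: the underlying IterableDataset has no __len__
except TypeError:
    pass
```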