[bugfix] Fix dataloading for iterable datasets and limit_train_batches (#7306)

* bugfix-dataloading

* rm-logs

* Update CHANGELOG.md

* Update test_dataloaders.py

* Update test_dataloaders.py

* Update training_loop.py

* Update test_dataloaders.py

* Update CHANGELOG.md

* Update CHANGELOG.md

* Update test_dataloaders.py

* Update training_loop.py

* Update training_loop.py

* comments

* address comments

* more tests

* Update progress.py

* Update test_dataloaders.py

* Update test_dataloaders.py

* Update training_loop.py

* Update training_loop.py

* test ckpt fix?

* update again
ananthsub authored May 3, 2021
1 parent 7636d42 commit 14c552b
Showing 6 changed files with 241 additions and 42 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -297,6 +297,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed


- Fixed NaN errors in progress bars when training with iterable datasets with no length defined ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306))


- Fixed validation being skipped for iterable datasets with no length defined ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306))


- Fixed attaching train and validation dataloaders when `reload_dataloaders_every_epoch=True` and `num_sanity_val_steps=0` ([#7207](https://github.com/PyTorchLightning/pytorch-lightning/pull/7207))
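
As a rough illustration of the setup the two #7306 entries above describe (a sketch for context, not part of the commit): an IterableDataset with no __len__ combined with an integer limit_train_batches, the configuration under which validation used to be skipped and the progress bar showed NaN totals. BoringModel and RandomIterableDataset are the test helpers added in tests/helpers/boring_model.py further down this diff, so the import assumes the repository's test environment.

from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomIterableDataset

# dataloaders over an iterable dataset with no __len__: the trainer cannot infer
# an epoch length, so num_training_batches and val_check_batch are both inf
train_loader = DataLoader(RandomIterableDataset(size=32, count=64), batch_size=4)
val_loader = DataLoader(RandomIterableDataset(size=32, count=64), batch_size=4)

trainer = Trainer(
    max_epochs=1,
    limit_train_batches=10,  # integer cap on the otherwise unbounded epoch
)
# with this fix, validation runs every 10 train batches and the progress bar
# falls back to an open-ended total instead of NaN
trainer.fit(BoringModel(), train_dataloader=train_loader, val_dataloaders=val_loader)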


4 changes: 1 addition & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -231,9 +231,7 @@ def on_train_batch_end(
        self.save_checkpoint(trainer)

    def on_validation_end(self, trainer, pl_module) -> None:
-        """
-        checkpoints can be saved at the end of the val loop
-        """
+        """ Save a checkpoint at the end of the validation stage. """
        skip = (
            self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
            or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
15 changes: 9 additions & 6 deletions pytorch_lightning/callbacks/progress.py
@@ -20,6 +20,7 @@
"""
import importlib
import io
import math
import os
import sys

@@ -397,7 +398,7 @@ def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
-        if total_train_batches != float('inf'):
+        if total_train_batches != float('inf') and total_val_batches != float('inf'):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch
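
For context, an illustrative calculation (not from the diff) of why the scaling above now requires both totals to be finite:

# finite case: 100 train batches checked every 25 batches -> validation runs 4 times per epoch
total_train_batches = 100
total_val_batches = 10
val_check_batch = 25
val_checks_per_epoch = total_train_batches // val_check_batch                      # 4
main_bar_total = total_train_batches + total_val_batches * val_checks_per_epoch    # 140

# unsized iterable dataset: both totals are inf, so the scaling is skipped;
# float('inf') // 25 is still inf and would leave the main bar with an unusable total
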
@@ -407,7 +408,9 @@ def on_train_epoch_start(self, trainer, pl_module):

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
+        total_batches = self.total_train_batches + self.total_val_batches
+        total_batches = convert_inf(total_batches)
+        if self._should_update(self.train_batch_idx, total_batches):
            self._update_bar(self.main_progress_bar)
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

@@ -422,7 +425,7 @@ def on_validation_start(self, trainer, pl_module):

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self._should_update(self.val_batch_idx, self.total_val_batches):
+        if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)):
            self._update_bar(self.val_progress_bar)
            self._update_bar(self.main_progress_bar)

@@ -479,7 +482,7 @@ def print(
            s = sep.join(map(str, args))
            active_progress_bar.write(s, end=end, file=file, nolock=nolock)

-    def _should_update(self, current, total):
+    def _should_update(self, current, total) -> bool:
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    def _update_bar(self, bar: Optional[tqdm]) -> None:
@@ -496,8 +499,8 @@ def _update_bar(self, bar: Optional[tqdm]) -> None:


def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
-    """ The tqdm doesn't support inf values. We have to convert it to None. """
-    if x == float('inf'):
+    """ The tqdm doesn't support inf/nan values. We have to convert it to None. """
+    if x is None or math.isinf(x) or math.isnan(x):
        return None
    return x
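
A quick sketch (not part of the diff) of what the extended guard buys: with an unsized train dataloader the combined bar total is inf, and convert_inf now maps inf, NaN, and None alike to None, which tqdm treats as an unknown total.

from pytorch_lightning.callbacks.progress import convert_inf

total_train_batches = float("inf")  # iterable dataset with no length defined
total_val_batches = 2

# the fixed on_train_batch_end converts the combined total before handing it to tqdm
assert convert_inf(total_train_batches + total_val_batches) is None
assert convert_inf(float("nan")) is None
assert convert_inf(10) == 10  # finite totals pass through unchanged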

43 changes: 29 additions & 14 deletions pytorch_lightning/trainer/training_loop.py
@@ -484,9 +484,9 @@ def run_training_epoch(self):
            self.trainer.logger_connector.log_train_step_metrics(batch_output)

            # -----------------------------------------
-            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
+            # VALIDATE IF NEEDED
            # -----------------------------------------
-            should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
+            should_check_val = self._should_check_val_fx(batch_idx, is_last_batch)
            if should_check_val:
                self.trainer.validating = True
                self.trainer.run_evaluation()
@@ -535,7 +535,7 @@ def run_training_epoch(self):
        # log epoch metrics
        self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)

-        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+        should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
        should_train_only = self.trainer.disable_validation or should_skip_eval

@@ -825,19 +825,34 @@ def should_accumulate(self):
        is_final_batch = self._num_training_batches_reached()
        return not (accumulation_done or is_final_batch)

-    def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
-        # decide if we should run validation
-        is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
-        is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
-        can_check_val = self.trainer.enable_validation and is_val_check_epoch
-        is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
-
-        should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
-                            or is_last_batch_for_infinite_dataset
-                            ) if on_epoch else (is_val_check_batch and not epoch_end_val_check)
-
-        return should_check_val and can_check_val
+    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
+        """ Decide if we should run validation. """
+
+        if not self.trainer.enable_validation:
+            return False
+
+        # check if this epoch is eligible to run validation
+        if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
+            return False
+
+        # val_check_batch is inf for iterable datasets with no length defined
+        # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
+        is_val_check_batch = False
+        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
+            is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
+        elif self.trainer.val_check_batch != float('inf'):
+            is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
+
+        # Note: num_training_batches is also inf for iterable datasets with no length defined
+        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
+        is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
+
+        if on_epoch:
+            return (
+                is_val_check_batch and epoch_end_val_check
+            ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
+        else:
+            return is_val_check_batch and not epoch_end_val_check
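
To make the new mid-epoch branch concrete, here is a small standalone sketch (plain variables instead of the Trainer, not code from the commit): for an iterable dataset with no length, val_check_batch and num_training_batches are both inf, so the check now keys off the integer limit_train_batches.

# values the trainer would hold for an unsized iterable dataset and Trainer(limit_train_batches=10)
limit_train_batches = 10
val_check_batch = float("inf")
num_training_batches = float("inf")

def should_check_val_mid_epoch(batch_idx: int) -> bool:
    # mirrors the on_epoch=False path of _should_check_val_fx above (sketch only)
    is_val_check_batch = False
    if isinstance(limit_train_batches, int) and val_check_batch == float("inf"):
        is_val_check_batch = (batch_idx + 1) % limit_train_batches == 0
    elif val_check_batch != float("inf"):
        is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    epoch_end_val_check = (batch_idx + 1) % num_training_batches == 0  # 10 % inf == 10.0, never 0
    return is_val_check_batch and not epoch_end_val_check

assert should_check_val_mid_epoch(batch_idx=9)       # validation runs after the 10th batch
assert not should_check_val_mid_epoch(batch_idx=4)   # but not before the cap is reached
# previously (batch_idx + 1) % val_check_batch was always nonzero here, so validation never ran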

    def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
        # enable not needing to add opt_idx to training_step
27 changes: 26 additions & 1 deletion tests/helpers/boring_model.py
@@ -14,7 +14,7 @@
from typing import Optional

import torch
-from torch.utils.data import DataLoader, Dataset, Subset
+from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset

from pytorch_lightning import LightningDataModule, LightningModule

@@ -60,6 +60,31 @@ def __len__(self):
        return self.len


class RandomIterableDataset(IterableDataset):

    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)


class RandomIterableDatasetWithLen(IterableDataset):

    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self):
        for _ in range(len(self)):
            yield torch.randn(self.size)

    def __len__(self):
        return self.count
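
For reference, a usage sketch (not part of the diff): the two helpers differ only in whether a length is defined, which is what the trainer keys off when deciding whether an epoch is bounded.

from torch.utils.data import DataLoader

no_len_ds = RandomIterableDataset(size=32, count=8)
with_len_ds = RandomIterableDatasetWithLen(size=32, count=8)

len(with_len_ds)       # 8 -> the trainer can derive num_training_batches
try:
    len(no_len_ds)     # no __len__ -> num_training_batches stays inf
except TypeError:
    pass

train_loader = DataLoader(no_len_ds, batch_size=4)  # treated as an unsized/infinite loader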


class BoringModel(LightningModule):

    def __init__(self):
