diff --git a/CHANGELOG.md b/CHANGELOG.md index b629e9e72a9a5..6573a11692eb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -297,6 +297,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed NaN errors in progress bars when training with iterable datasets with no length defined ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) + + +- Fixed validation being skipped for iterable datasets with no length defined ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306)) + + - Fixed attaching train and validation dataloaders when `reload_dataloaders_every_epoch=True` and `num_sanity_val_steps=0` ([#7207](https://github.com/PyTorchLightning/pytorch-lightning/pull/7207)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index ac73efd67cb78..82839007b6851 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -231,9 +231,7 @@ def on_train_batch_end( self.save_checkpoint(trainer) def on_validation_end(self, trainer, pl_module) -> None: - """ - checkpoints can be saved at the end of the val loop - """ + """ Save a checkpoint at the end of the validation stage. """ skip = ( self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0 diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 5a76e5eb97331..be9d2f44356f5 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -20,6 +20,7 @@ """ import importlib import io +import math import os import sys @@ -397,7 +398,7 @@ def on_train_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches total_val_batches = self.total_val_batches - if total_train_batches != float('inf'): + if total_train_batches != float('inf') and total_val_batches != float('inf'): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch @@ -407,7 +408,9 @@ def on_train_epoch_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches): + total_batches = self.total_train_batches + self.total_val_batches + total_batches = convert_inf(total_batches) + if self._should_update(self.train_batch_idx, total_batches): self._update_bar(self.main_progress_bar) self.main_progress_bar.set_postfix(trainer.progress_bar_dict) @@ -422,7 +425,7 @@ def on_validation_start(self, trainer, pl_module): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.val_batch_idx, self.total_val_batches): + if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)): self._update_bar(self.val_progress_bar) self._update_bar(self.main_progress_bar) @@ -479,7 +482,7 @@ def print( s = sep.join(map(str, args)) active_progress_bar.write(s, end=end, file=file, nolock=nolock) - def _should_update(self, current, total): + def _should_update(self, current, total) -> bool: return self.is_enabled and (current % self.refresh_rate == 0 or current == total) def _update_bar(self, bar: Optional[tqdm]) -> None: @@ -496,8 +499,8 @@ def _update_bar(self, bar: Optional[tqdm]) -> None: def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: - """ The tqdm doesn't support inf values. We have to convert it to None. """ - if x == float('inf'): + """ The tqdm doesn't support inf/nan values. We have to convert it to None. """ + if x is None or math.isinf(x) or math.isnan(x): return None return x diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b34452d5cc7eb..f96c17a0686ce 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -484,9 +484,9 @@ def run_training_epoch(self): self.trainer.logger_connector.log_train_step_metrics(batch_output) # ----------------------------------------- - # VALIDATE IF NEEDED + CHECKPOINT CALLBACK + # VALIDATE IF NEEDED # ----------------------------------------- - should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) + should_check_val = self._should_check_val_fx(batch_idx, is_last_batch) if should_check_val: self.trainer.validating = True self.trainer.run_evaluation() @@ -535,7 +535,7 @@ def run_training_epoch(self): # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) - should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) + should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval @@ -825,19 +825,34 @@ def should_accumulate(self): is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) - def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False): - # decide if we should run validation - is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 - is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 - can_check_val = self.trainer.enable_validation and is_val_check_epoch - is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") - epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 + def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: + """ Decide if we should run validation. """ + + if not self.trainer.enable_validation: + return False + + # check if this epoch is eligible to run validation + if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: + return False - should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop - or is_last_batch_for_infinite_dataset - ) if on_epoch else (is_val_check_batch and not epoch_end_val_check) + # val_check_batch is inf for iterable datasets with no length defined + # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch + is_val_check_batch = False + if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): + is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 + elif self.trainer.val_check_batch != float('inf'): + is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 - return should_check_val and can_check_val + # Note: num_training_batches is also inf for iterable datasets with no length defined + epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 + is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") + + if on_epoch: + return ( + is_val_check_batch and epoch_end_val_check + ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset + else: + return is_val_check_batch and not epoch_end_val_check def build_train_args(self, batch, batch_idx, opt_idx, hiddens): # enable not needing to add opt_idx to training_step diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index 2e7c626306c36..eb81baeb2c29d 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -14,7 +14,7 @@ from typing import Optional import torch -from torch.utils.data import DataLoader, Dataset, Subset +from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset from pytorch_lightning import LightningDataModule, LightningModule @@ -60,6 +60,31 @@ def __len__(self): return self.len +class RandomIterableDataset(IterableDataset): + + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(self.count): + yield torch.randn(self.size) + + +class RandomIterableDatasetWithLen(IterableDataset): + + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(len(self)): + yield torch.randn(self.size) + + def __len__(self): + return self.count + + class BoringModel(LightningModule): def __init__(self): diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 240ddfa37b46e..6b0ea97d41a70 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -31,7 +31,7 @@ from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate -from tests.helpers.boring_model import BoringModel, RandomDataset +from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen from tests.helpers.runif import RunIf @@ -233,60 +233,212 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): assert len(trainer.test_dataloaders) == n +class DummyModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.log("loss", self.global_step) + return super().training_step(batch, batch_idx) + + def validation_epoch_end(self, outputs): + self.log("val_log", self.current_epoch) + + +class Counter(Callback): + + def __init__(self): + super().__init__() + self.train_epoch_count = 0 + self.val_epoch_count = 0 + self.test_epoch_count = 0 + self.train_batches_seen = 0 + self.val_batches_seen = 0 + self.test_batches_seen = 0 + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.train_batches_seen += 1 + + def on_train_epoch_start(self, trainer, pl_module): + self.train_epoch_count += 1 + + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.val_batches_seen += 1 + + def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.test_batches_seen += 1 + + def on_validation_epoch_start(self, trainer, pl_module): + self.val_epoch_count += 1 + + def on_test_epoch_start(self, trainer, pl_module): + self.test_epoch_count += 1 + + @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ (0.0, 0.0, 0.0), (1.0, 1.0, 1.0), ]) def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent""" - model = EvalModelTemplate() - model.train_dataloader = model.train_dataloader__infinite - model.val_dataloader = model.val_dataloader__infinite - model.test_dataloader = model.test_dataloader__infinite + ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) + epoch_cb = Counter() trainer = Trainer( default_root_dir=tmpdir, + num_sanity_val_steps=0, max_epochs=1, + callbacks=[epoch_cb, ckpt_callback], limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches, limit_test_batches=limit_test_batches, ) + model = DummyModel() - trainer.fit(model) + batch_size = 8 + train_dl = DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) + val_dl = DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) + test_dl = DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) + + num_batches = 128 / batch_size + for dl in (train_dl, val_dl, test_dl): + if has_len(dl): + assert len(dl) == num_batches + else: + assert sum(1 for _ in dl) == num_batches + + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf')) + assert epoch_cb.train_epoch_count == int(limit_train_batches > 0) assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf')) + assert epoch_cb.val_epoch_count == int(limit_val_batches > 0) - trainer.test(ckpt_path=None) + trainer.test(model, test_dataloaders=test_dl) assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf')) + assert epoch_cb.test_epoch_count == int(limit_test_batches > 0) -@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - (0, 0, 0), - (10, 10, 10), +@pytest.mark.parametrize(['dataset', 'limit_train_batches'], [ + (RandomDataset(32, 128), 0), + (RandomDataset(32, 128), 10), + (RandomIterableDataset(32, 128), 0), + (RandomIterableDataset(32, 128), 10), + (RandomIterableDatasetWithLen(32, 128), 0), + (RandomIterableDatasetWithLen(32, 128), 10), ]) -def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): +def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batches): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" - model = EvalModelTemplate() - model.train_dataloader = model.train_dataloader__infinite - model.val_dataloader = model.val_dataloader__infinite - model.test_dataloader = model.test_dataloader__infinite + ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) + epoch_cb = Counter() + epochs = 2 trainer = Trainer( default_root_dir=tmpdir, - max_epochs=1, + num_sanity_val_steps=0, + max_epochs=epochs, + callbacks=[epoch_cb, ckpt_callback], + limit_train_batches=limit_train_batches, + ) + model = DummyModel() + + batch_size = 8 + train_dl = DataLoader(dataset=dataset, batch_size=batch_size) + val_dl = DataLoader(dataset=dataset, batch_size=batch_size) + + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.num_training_batches == limit_train_batches + assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) + assert epoch_cb.train_batches_seen == limit_train_batches * epochs + + +@pytest.mark.parametrize( + ['dataset', 'limit_val_batches'], + [ + (RandomDataset(32, 128), 0), + (RandomDataset(32, 128), 10), + (RandomIterableDataset(32, 128), 0), + (RandomIterableDataset(32, 128), 10), + (RandomIterableDatasetWithLen(32, 128), 0), + # TODO: enable this after #6671 is merged + # (RandomIterableDatasetWithLen(32, 128), 10), + ] +) +def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches): + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" + + epoch_cb = Counter() + callbacks = [epoch_cb] + checkpoint_callback = True + if limit_val_batches > 0: + callbacks.append(ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False)) + else: + checkpoint_callback = False + epochs = 2 + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + max_epochs=epochs, + callbacks=callbacks, + limit_val_batches=limit_val_batches, + checkpoint_callback=checkpoint_callback, + ) + model = DummyModel() + + batch_size = 8 + train_dl = DataLoader(dataset=dataset, batch_size=batch_size) + val_dl = DataLoader(dataset=dataset, batch_size=batch_size) + + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.num_val_batches[0] == limit_val_batches + assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0) + assert epoch_cb.val_batches_seen == limit_val_batches * epochs + + +@pytest.mark.parametrize(['dataset', 'limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ + (RandomDataset(32, 128), 0, 0, 0), + (RandomDataset(32, 128), 10, 10, 10), + (RandomIterableDataset(32, 128), 0, 0, 0), + (RandomIterableDataset(32, 128), 10, 10, 10), + (RandomIterableDatasetWithLen(32, 128), 0, 0, 0), + (RandomIterableDatasetWithLen(32, 128), 10, 10, 10), +]) +def test_datasets_dataloaders_with_limit_num_batches( + tmpdir, dataset, limit_train_batches, limit_val_batches, limit_test_batches +): + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" + + ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) + epoch_cb = Counter() + epochs = 2 + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + max_epochs=epochs, + callbacks=[epoch_cb, ckpt_callback], limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches, limit_test_batches=limit_test_batches, ) + model = DummyModel() - trainer.fit(model) + batch_size = 8 + train_dl = DataLoader(dataset=dataset, batch_size=batch_size) + val_dl = DataLoader(dataset=dataset, batch_size=batch_size) + test_dl = DataLoader(dataset=dataset, batch_size=batch_size) + + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches[0] == limit_val_batches + assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) + assert epoch_cb.train_batches_seen == limit_train_batches * epochs + assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0) + assert epoch_cb.val_batches_seen == limit_val_batches * epochs - trainer.test(ckpt_path=None) + trainer.test(model, test_dataloaders=test_dl) assert trainer.num_test_batches[0] == limit_test_batches + assert epoch_cb.test_epoch_count == int(limit_test_batches > 0) @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [