From 51d44f57dd8de1822faa81a12eb7ca631ddfe504 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Thu, 16 Feb 2023 15:10:16 +0100
Subject: [PATCH] Prefetch if it's not a sized iterable (#16776)

---
 src/lightning/pytorch/loops/dataloader/evaluation_loop.py | 8 +-------
 src/lightning/pytorch/loops/fetchers.py                   | 8 ++++++--
 src/lightning/pytorch/loops/fit_loop.py                   | 7 +------
 src/lightning/pytorch/loops/utilities.py                  | 4 ++--
 tests/tests_pytorch/loops/test_fetchers.py                | 5 +++--
 5 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py b/src/lightning/pytorch/loops/dataloader/evaluation_loop.py
index 24914733e3ec7..3e30da9adaeb1 100644
--- a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/dataloader/evaluation_loop.py
@@ -74,12 +74,6 @@ def dataloaders(self) -> Sequence[DataLoader]:
             return []
         return dataloaders
 
-    @property
-    def prefetch_batches(self) -> int:
-        batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches
-        is_unsized = batches[self.current_dataloader_idx] == float("inf")
-        return int(is_unsized)
-
     @property
     def done(self) -> bool:
         """Returns whether all dataloaders are processed or evaluation should be skipped altogether."""
@@ -126,7 +120,7 @@ def reset(self) -> None:
     def on_run_start(self) -> None:
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
-        self._data_fetcher = _select_data_fetcher(self.trainer, prefetch_batches=self.prefetch_batches)
+        self._data_fetcher = _select_data_fetcher(self.trainer)
 
         # hook
         self._on_evaluation_model_eval()
diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py
index 8d983b6e91f04..b7385eb6b5f1b 100644
--- a/src/lightning/pytorch/loops/fetchers.py
+++ b/src/lightning/pytorch/loops/fetchers.py
@@ -81,7 +81,7 @@ class _PrefetchDataFetcher(_DataFetcher):
 
     Args:
         prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
-            whether a batch is the last one (available with :attr:`self.done`) under any training setup.
+            whether a batch is the last one (available with :attr:`self.done`) when the length is not available.
""" def __init__(self, prefetch_batches: int = 1) -> None: @@ -98,6 +98,10 @@ def setup(self, dataloader: Iterable) -> None: def __iter__(self) -> "_PrefetchDataFetcher": super().__iter__() + if self._has_len: + # ignore pre-fetching, it's not necessary + return self + # prefetch batches to know when the iterator will be exhausted in advance iterator = self.dataloader_iter assert iterator is not None for _ in range(self.prefetch_batches): @@ -143,7 +147,7 @@ def _fetch_next_batch(self, iterator: Iterator) -> None: finally: self._stop_profiler() self.fetched += 1 - if not self.prefetch_batches and self._has_len: + if self._has_len: # when we don't prefetch but the dataloader is sized, we use the length for `done` dataloader = self.dataloader assert isinstance(dataloader, Sized) # `_has_len` is True diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 73fc7a37175e5..ef9f44c13f465 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -121,11 +121,6 @@ def restarting(self, restarting: bool) -> None: restarting = restarting and epoch_unfinished or self._iteration_based_training() _Loop.restarting.fset(self, restarting) # call the parent setter - @property - def prefetch_batches(self) -> int: - is_unsized = self.trainer.num_training_batches == float("inf") - return int(is_unsized) - @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" @@ -219,7 +214,7 @@ def on_run_start(self) -> None: if self.epoch_loop._should_check_val_epoch(): self.epoch_loop.val_loop._reload_evaluation_dataloaders() - self._data_fetcher = _select_data_fetcher(trainer, self.prefetch_batches) + self._data_fetcher = _select_data_fetcher(trainer) self._is_fresh_start_epoch = True self._results.to(device=trainer.lightning_module.device) diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index d61afe4b98d58..b6f61b75036b6 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -136,7 +136,7 @@ def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None: sampler.set_epoch(epoch) -def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _DataFetcher: +def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher: lightning_module = trainer.lightning_module if trainer.testing: step_fx_name = "test_step" @@ -153,7 +153,7 @@ def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _D "this signature is experimental and the behavior is subject to change." 
         )
         return _DataLoaderIterDataFetcher()
-    return _PrefetchDataFetcher(prefetch_batches=prefetch_batches)
+    return _PrefetchDataFetcher()
 
 
 def _no_grad_context(loop_run: Callable) -> Callable:
diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py
index a5c1462f7f4aa..07c6c1507c8d3 100644
--- a/tests/tests_pytorch/loops/test_fetchers.py
+++ b/tests/tests_pytorch/loops/test_fetchers.py
@@ -64,8 +64,9 @@ def generate():
 
     # we can only know the last batch with sized iterables or when we prefetch
     is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset]
-    fetched = list(range(prefetch_batches + 1, 4))
-    fetched += [3] * (3 - len(fetched))
+    fetched = (
+        [1, 2, 3] if dataset_cls is SizedDataset else [1, 2, 3, 3, 3, 3, 3][prefetch_batches : prefetch_batches + 3]
+    )
     batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3]
     expected = list(zip(fetched, batches, is_last_batch))
     assert len(expected) == 3
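
Editor's note: the following is a minimal, self-contained sketch of the behavior this patch settles on, not the Lightning implementation; the helper name `iterate_with_done_flag` is made up for illustration. It shows why a sized iterable needs no prefetching to flag the last batch (the length already tells you), while an unsized iterable needs exactly one prefetched batch to detect exhaustion one step ahead.

# Minimal sketch, NOT the Lightning implementation.
from collections.abc import Sized
from typing import Any, Iterable, Iterator, Tuple


def iterate_with_done_flag(iterable: Iterable) -> Iterator[Tuple[Any, bool]]:
    """Yield ``(batch, is_last)`` pairs."""
    if isinstance(iterable, Sized):
        # Sized: derive the last-batch flag from len(), no prefetching needed.
        length = len(iterable)
        for i, batch in enumerate(iterable, start=1):
            yield batch, i == length
        return
    # Unsized: prefetch a single batch so exhaustion is detected in advance.
    it = iter(iterable)
    try:
        prefetched = next(it)
    except StopIteration:
        return
    for batch in it:
        yield prefetched, False
        prefetched = batch
    yield prefetched, True  # the prefetched batch turned out to be the last one


if __name__ == "__main__":
    # Both print [(1, False), (2, False), (3, True)]
    print(list(iterate_with_done_flag([1, 2, 3])))        # sized list: no prefetch
    print(list(iterate_with_done_flag(iter([1, 2, 3]))))  # unsized iterator: prefetch one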