Prefetch if it's not a sized iterable (#16776)
carmocca authored Feb 16, 2023
1 parent c9452df commit 51d44f5
Showing 5 changed files with 13 additions and 19 deletions.
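At a high level, the commit makes batch prefetching conditional on whether the iterable reports a length: a sized dataloader already tells the loop when the last batch arrives, so pulling batches ahead of time only matters for unsized iterables. The snippet below is a minimal, self-contained sketch of that idea; the function name and structure are illustrative only and do not mirror the actual `_PrefetchDataFetcher` implementation.

from collections.abc import Sized
from typing import Any, Iterable, Iterator, Tuple


def iterate_with_done_flag(iterable: Iterable) -> Iterator[Tuple[Any, bool]]:
    """Yield (batch, is_last) pairs, prefetching only when the length is unknown."""
    if isinstance(iterable, Sized):
        # Sized case: the length tells us which batch is the last one, no prefetching needed.
        length = len(iterable)
        for i, batch in enumerate(iterable, start=1):
            yield batch, i == length
        return
    # Unsized case: keep one batch in flight so iterator exhaustion marks the previous batch as last.
    it = iter(iterable)
    try:
        prefetched = next(it)
    except StopIteration:
        return
    for nxt in it:
        yield prefetched, False
        prefetched = nxt
    yield prefetched, True

Both a list and a generator over the same three batches yield (10, False), (20, False), (30, True) with this sketch; only the generator path actually pulls a batch ahead of time, which is the behavior the commit title describes.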
8 changes: 1 addition & 7 deletions src/lightning/pytorch/loops/dataloader/evaluation_loop.py
@@ -74,12 +74,6 @@ def dataloaders(self) -> Sequence[DataLoader]:
             return []
         return dataloaders
 
-    @property
-    def prefetch_batches(self) -> int:
-        batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches
-        is_unsized = batches[self.current_dataloader_idx] == float("inf")
-        return int(is_unsized)
-
     @property
     def done(self) -> bool:
         """Returns whether all dataloaders are processed or evaluation should be skipped altogether."""
@@ -126,7 +120,7 @@ def reset(self) -> None:
     def on_run_start(self) -> None:
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
-        self._data_fetcher = _select_data_fetcher(self.trainer, prefetch_batches=self.prefetch_batches)
+        self._data_fetcher = _select_data_fetcher(self.trainer)
 
         # hook
         self._on_evaluation_model_eval()
8 changes: 6 additions & 2 deletions src/lightning/pytorch/loops/fetchers.py
@@ -81,7 +81,7 @@ class _PrefetchDataFetcher(_DataFetcher):
 
     Args:
         prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
-            whether a batch is the last one (available with :attr:`self.done`) under any training setup.
+            whether a batch is the last one (available with :attr:`self.done`) when the length is not available.
     """
 
     def __init__(self, prefetch_batches: int = 1) -> None:
@@ -98,6 +98,10 @@ def setup(self, dataloader: Iterable) -> None:
 
     def __iter__(self) -> "_PrefetchDataFetcher":
         super().__iter__()
+        if self._has_len:
+            # ignore pre-fetching, it's not necessary
+            return self
+        # prefetch batches to know when the iterator will be exhausted in advance
         iterator = self.dataloader_iter
         assert iterator is not None
         for _ in range(self.prefetch_batches):
@@ -143,7 +147,7 @@ def _fetch_next_batch(self, iterator: Iterator) -> None:
         finally:
             self._stop_profiler()
         self.fetched += 1
-        if not self.prefetch_batches and self._has_len:
+        if self._has_len:
             # when we don't prefetch but the dataloader is sized, we use the length for `done`
             dataloader = self.dataloader
             assert isinstance(dataloader, Sized)  # `_has_len` is True
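The early return in `__iter__` above keys off the fetcher's `_has_len` flag, which is set during `setup`. Roughly, that flag answers "does `len(dataloader)` work?"; a hedged, stand-alone approximation of such a check (not Lightning's actual `has_len`-style utility) could look like:

from typing import Iterable


def can_use_len(iterable: Iterable) -> bool:
    # Illustrative stand-in for the `_has_len` flag: sized loaders report a usable
    # length, while plain iterators and iterable-style datasets raise when `len()`
    # is called on them.
    try:
        return len(iterable) >= 0  # type: ignore[arg-type]
    except (TypeError, NotImplementedError):
        return False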
7 changes: 1 addition & 6 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -121,11 +121,6 @@ def restarting(self, restarting: bool) -> None:
         restarting = restarting and epoch_unfinished or self._iteration_based_training()
         _Loop.restarting.fset(self, restarting)  # call the parent setter
 
-    @property
-    def prefetch_batches(self) -> int:
-        is_unsized = self.trainer.num_training_batches == float("inf")
-        return int(is_unsized)
-
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""
@@ -219,7 +214,7 @@ def on_run_start(self) -> None:
         if self.epoch_loop._should_check_val_epoch():
             self.epoch_loop.val_loop._reload_evaluation_dataloaders()
 
-        self._data_fetcher = _select_data_fetcher(trainer, self.prefetch_batches)
+        self._data_fetcher = _select_data_fetcher(trainer)
 
         self._is_fresh_start_epoch = True
         self._results.to(device=trainer.lightning_module.device)
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/utilities.py
@@ -136,7 +136,7 @@ def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
             sampler.set_epoch(epoch)
 
 
-def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _DataFetcher:
+def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher:
     lightning_module = trainer.lightning_module
     if trainer.testing:
         step_fx_name = "test_step"
@@ -153,7 +153,7 @@ def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _DataFetcher:
             "this signature is experimental and the behavior is subject to change."
        )
         return _DataLoaderIterDataFetcher()
-    return _PrefetchDataFetcher(prefetch_batches=prefetch_batches)
+    return _PrefetchDataFetcher()
 
 
 def _no_grad_context(loop_run: Callable) -> Callable:
5 changes: 3 additions & 2 deletions tests/tests_pytorch/loops/test_fetchers.py
@@ -64,8 +64,9 @@ def generate():
 
     # we can only know the last batch with sized iterables or when we prefetch
     is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset]
-    fetched = list(range(prefetch_batches + 1, 4))
-    fetched += [3] * (3 - len(fetched))
+    fetched = (
+        [1, 2, 3] if dataset_cls is SizedDataset else [1, 2, 3, 3, 3, 3, 3][prefetch_batches : prefetch_batches + 3]
+    )
     batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3]
     expected = list(zip(fetched, batches, is_last_batch))
     assert len(expected) == 3
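As a quick sanity check on the new expectation for unsized iterables, evaluating the slice for a few `prefetch_batches` values shows how the fetched counter runs ahead of the batch index and saturates at the dataset length of 3 (illustrative only, not part of the test):

for prefetch_batches in range(5):
    print(prefetch_batches, [1, 2, 3, 3, 3, 3, 3][prefetch_batches : prefetch_batches + 3])
# 0 [1, 2, 3]   no prefetching: the counter matches the batch index
# 1 [2, 3, 3]   one batch is always in flight, so the counter runs ahead
# 2 [3, 3, 3]
# 3 [3, 3, 3]   with only 3 batches, the counter saturates at 3
# 4 [3, 3, 3]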
