remove redundant iterator call to data fetcher in loops #9117

Merged: 4 commits, Aug 26, 2021
3 changes: 1 addition & 2 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -101,11 +101,10 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         dataloader_idx: int = self.current_dataloader_idx
         dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
         dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
-        dataloader_iter = iter(dataloader)

         dl_max_batches = self._max_batches[dataloader_idx]

-        dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders)
+        dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)

         # store batch level output per dataloader
         if self.should_track_batch_outputs_for_epoch_end:
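The deleted `iter()` call was redundant: the object returned by `get_profiled_dataloader` is a data fetcher that is itself iterable, and the epoch loop creates the iterator where it is actually consumed. A minimal sketch of the general point, using a plain iterable in place of the fetcher:

```python
# A plain iterable stands in for the data fetcher here; iter() at the call
# site adds nothing, because any consumer (a for loop, enumerate, iter())
# can create the iterator itself when it needs one.
batches = ["batch0", "batch1", "batch2"]

explicit = list(enumerate(iter(batches)))  # with the redundant iter() call
implicit = list(enumerate(batches))        # consumer iterates directly
assert explicit == implicit
```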
11 changes: 6 additions & 5 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -21,6 +21,7 @@
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
 from pytorch_lightning.trainer.progress import Progress
+from pytorch_lightning.utilities.fetching import AbstractDataFetcher
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -58,12 +59,12 @@ def reset(self) -> None:
         self.batch_progress.current.reset()

     def on_run_start(
-        self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
+        self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
     ) -> None:
         """Adds the passed arguments to the loop's state if necessary

         Args:
-            dataloader_iter: iterator over the dataloader
+            data_fetcher: the current data_fetcher wrapping the dataloader
             dataloader_idx: index of the current dataloader
             dl_max_batches: maximum number of batches the dataloader can produce
             num_dataloaders: the total number of dataloaders
@@ -72,10 +73,10 @@ def on_run_start(
         self._dl_max_batches = dl_max_batches
         self._num_dataloaders = num_dataloaders

-        self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready)
+        self.dataloader_iter = _prepare_dataloader_iter(data_fetcher, self.batch_progress.current.ready)

     def advance(
-        self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
+        self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
     ) -> None:
         """Calls the evaluation step with the corresponding hooks and updates the logger connector.

@@ -88,7 +89,7 @@
         Raises:
             StopIteration: If the current batch is None
         """
-        void(data_fetcher, dl_max_batches, num_dataloaders)
+        void(data_fetcher, dl_max_batches, num_dataloaders)

         batch_idx, (batch, _) = next(self.dataloader_iter)
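For context, a small self-contained illustration of the tuples that `batch_idx, (batch, _) = next(self.dataloader_iter)` consumes. The `toy_fetcher` name is hypothetical, and the `(batch, is_last)` yield shape is an assumption inferred from that unpacking; `_prepare_dataloader_iter` wraps the fetcher in `enumerate()`, which prepends the running batch index:

```python
from typing import Any, Iterator, List, Tuple


def toy_fetcher(batches: List[Any]) -> Iterator[Tuple[Any, bool]]:
    # hypothetical fetcher: assumed to yield (batch, is_last) pairs
    for i, batch in enumerate(batches):
        yield batch, i == len(batches) - 1


# for non-DataLoaderIterDataFetcher inputs, _prepare_dataloader_iter
# wraps the fetcher in enumerate(), prepending the batch index
dataloader_iter = enumerate(toy_fetcher(["b0", "b1"]), 0)
batch_idx, (batch, is_last) = next(dataloader_iter)
assert (batch_idx, batch, is_last) == (0, "b0", False)
```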
5 changes: 2 additions & 3 deletions pytorch_lightning/loops/fit_loop.py
@@ -193,12 +193,11 @@ def on_advance_start(self) -> None:
     def advance(self) -> None:
         """Runs one whole epoch."""
         dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
-        dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader)
-        dataloader_iter = iter(dataloader)
+        data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)

         with self.trainer.profiler.profile("run_training_epoch"):
             # run train epoch
-            epoch_output = self.epoch_loop.run(dataloader_iter)
+            epoch_output = self.epoch_loop.run(data_fetcher)

         if epoch_output is None:
             return
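As a rough sketch of the concept only (the `ProfiledIterable` class below is hypothetical, not Lightning's data fetcher): a "profiled dataloader" wrapper can time each batch fetch while staying iterable, which is why the fit loop can now hand it straight to `epoch_loop.run()` without the removed `iter()` call:

```python
import time
from typing import Any, Iterator, List


class ProfiledIterable:
    """Hypothetical wrapper: times each batch fetch, stays iterable."""

    def __init__(self, dataloader: Any) -> None:
        self.dataloader = dataloader
        self.fetch_times: List[float] = []

    def __iter__(self) -> Iterator:
        iterator = iter(self.dataloader)
        while True:
            start = time.perf_counter()
            try:
                batch = next(iterator)
            except StopIteration:
                return
            # record how long fetching this batch took
            self.fetch_times.append(time.perf_counter() - start)
            yield batch


profiled = ProfiledIterable(range(3))
assert list(profiled) == [0, 1, 2] and len(profiled.fetch_times) == 3
```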
12 changes: 7 additions & 5 deletions pytorch_lightning/loops/utilities.py
@@ -20,7 +20,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher
+from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -105,9 +105,11 @@ def _process_training_step_output(
     return results, hiddens


-def _prepare_dataloader_iter(dataloader_iter: Iterator, batch_idx: int) -> Iterator:
+def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
     """Attach the dataloader"""
-    if not isinstance(dataloader_iter, DataLoaderIterDataFetcher):
-        dataloader_iter = enumerate(dataloader_iter, batch_idx)
-    # restore iteration
+    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
+        # restore iteration
+        dataloader_iter = enumerate(data_fetcher, batch_idx)
+    else:
+        dataloader_iter = iter(data_fetcher)
     return dataloader_iter

Reviewer (on the _prepare_dataloader_iter signature):

How much do we actually need to expose the "AbstractDataFetcher" here versus just an iterator? It seems to me that we're really not adding much value by exposing a new interface here. The custom logic below really should be part of the iterator implementation itself, as I comment below.

Contributor Author:

Yes, the iterator type annotation would have sufficed here.
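The author's point, as a minimal sketch (the `_prepare_iter` helper below is a hypothetical stand-in, not the Lightning function): the helper only iterates over its argument, so a structural `Iterable` annotation already covers both plain dataloaders and fetcher objects:

```python
from typing import Iterable, Iterator


def _prepare_iter(data: Iterable, batch_idx: int) -> Iterator:
    # enumerate() only requires something iterable, not a concrete
    # fetcher type, so no new interface needs to be exposed here
    return enumerate(data, batch_idx)


batches = ["batch0", "batch1"]
assert list(_prepare_iter(batches, 5)) == [(5, "batch0"), (6, "batch1")]
```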

Reviewer (on the DataLoaderIterDataFetcher check):

Why not simply move this into the contract of this iterator? Basically, make the contract such that these data fetchers will always return the batch_idx as part of it.

@awaelchli (Contributor Author, Aug 28, 2021):

Yes, I suppose it could be done like that :)
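A sketch of the reviewer's suggestion (the `IndexedDataFetcher` class and its names are hypothetical, not the merged implementation): if the fetcher itself yields the batch index, callers never branch on the fetcher type to restore the iteration count:

```python
from typing import Any, Iterator, Tuple


class IndexedDataFetcher:
    """Hypothetical fetcher whose contract includes the batch_idx."""

    def __init__(self, dataloader: Any, start_idx: int = 0) -> None:
        self.dataloader = dataloader
        self.start_idx = start_idx  # e.g. restored from progress tracking

    def __iter__(self) -> Iterator[Tuple[int, Any]]:
        # batch_idx is part of the contract, so the enumerate()/iter()
        # branching in _prepare_dataloader_iter would disappear
        return enumerate(self.dataloader, self.start_idx)


fetcher = IndexedDataFetcher(["batch0", "batch1"], start_idx=3)
assert list(fetcher) == [(3, "batch0"), (4, "batch1")]
```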