diff --git a/CHANGELOG.md b/CHANGELOG.md index 71a372bedbcdd..25eeedfbc850d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `DataLoaderIterDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020)) +- Added `DataFetcher` within `Fit / Evaluation` Loop ([#9047](https://github.com/PyTorchLightning/pytorch-lightning/pull/9047)) + + - Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005)) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index a048bc0c3a91c..163d7681f29a8 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -280,6 +280,8 @@ def _training_step( training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() + del step_kwargs + training_step_output = self.trainer.call_hook("training_step_end", training_step_output) _check_training_step_output(self.trainer.lightning_module, training_step_output) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 52998676e1923..3c24c921da389 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import Any, Iterator, List, Optional, Sequence, Union from deprecate.utils import void from torch.utils.data.dataloader import DataLoader @@ -20,7 +20,6 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -98,10 +97,13 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader""" void(*args, **kwargs) + dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - data_fetcher = DataFetcher() - data_fetcher.setup(dataloader) - dataloader_iter = enumerate(data_fetcher) + dataloader = self.trainer.data_connector.get_profiled_dataloader( + dataloader, dataloader_idx=self.current_dataloader_idx + ) + dataloader_iter = iter(dataloader) + dl_max_batches = self._max_batches[self.current_dataloader_idx] dl_outputs = self.epoch_loop.run( diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index b03b15d820230..eb3f9dad58bcf 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -91,8 +91,9 @@ def advance( if batch is None: raise StopIteration - with self.trainer.profiler.profile("evaluation_batch_to_device"): - batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) + if not 
self.trainer.data_connector.evaluation_data_fetcher.store_on_device: + with self.trainer.profiler.profile("evaluation_batch_to_device"): + batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) self.batch_progress.increment_ready() diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 7242ccde42632..741a05cd5701e 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -132,14 +132,13 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: else: _, (batch, is_last) = next(dataloader_iter) - # ------------------------------------ - # TRAINING_STEP + TRAINING_STEP_END - # ------------------------------------ - # FIXME: Remove with InterBatchProcessor. - if not self.trainer.data_connector.data_fetcher.store_on_device: + if not self.trainer.data_connector.train_data_fetcher.store_on_device: with self.trainer.profiler.profile("training_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch) + # ------------------------------------ + # TRAINING_STEP + TRAINING_STEP_END + # ------------------------------------ self.batch_progress.increment_ready() with self.trainer.profiler.profile("run_training_batch"): diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index b77f186453c6a..8507809c361ec 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,7 +14,7 @@ import logging from contextlib import suppress -from typing import Optional +from typing import Iterator, Optional from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop @@ -192,12 +192,13 @@ def on_advance_start(self) -> None: def advance(self) -> None: """Runs one whole epoch.""" - train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) + dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) + dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) + dataloader_iter = iter(dataloader) with self.trainer.profiler.profile("run_training_epoch"): # run train epoch - epoch_output = self.epoch_loop.run(train_dataloader) + epoch_output = self.epoch_loop.run(dataloader_iter) if epoch_output is None: return diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 438d081026874..c1981173215ae 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -113,7 +113,7 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: Args: dataloader_iter: the iterator over the dataloader producing the new batch """ - _, (dataloader_iter, batch_idx, is_last) = next(dataloader_iter) + batch_idx, (dataloader_iter, is_last) = next(dataloader_iter) self.trainer.logger_connector.on_batch_start() response = self.trainer.call_hook("on_batch_start") diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index faa5e9070b18e..8d337b972dce2 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -11,22 +11,47 @@ # WITHOUT 
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os +from functools import partial from typing import Callable, Iterable, Optional, Union import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, InterBatchParallelDataFetcher +from pytorch_lightning.utilities.fetching import ( + AbstractDataFetcher, + DataFetcher, + DataLoaderIterDataFetcher, + InterBatchParallelDataFetcher, +) from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from pytorch_lightning.utilities.warnings import rank_zero_warn class DataConnector: - def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): + def __init__( + self, + trainer: "pl.Trainer", + multiple_trainloader_mode: str = "max_size_cycle", + train_data_fetcher: Optional[AbstractDataFetcher] = None, + validate_data_fetcher: Optional[AbstractDataFetcher] = None, + test_data_fetcher: Optional[AbstractDataFetcher] = None, + ): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - self.data_fetcher: AbstractDataFetcher = DataFetcher() + + self.train_data_fetcher = train_data_fetcher + self.validate_data_fetcher = validate_data_fetcher + self.test_data_fetcher = test_data_fetcher + self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None + + @property + def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]: + if self.trainer.sanity_checking: + return self.sanity_check_data_fetcher + return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher def on_trainer_init( self, @@ -66,15 +91,42 @@ def on_trainer_init( self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs self.trainer._is_data_prepared = False - def get_profiled_train_dataloader(self, train_dataloader) -> Iterable: - # FIXME: Temporary hack - if isinstance(self.data_fetcher, InterBatchParallelDataFetcher): - self.data_fetcher.setup(train_dataloader, batch_to_device=self.trainer.accelerator.batch_to_device) - else: - self.data_fetcher.setup(train_dataloader) - prefetcher_iter = iter(self.data_fetcher) - profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch") - return profiled_dl + def _check_training_step_requires_dataloader_iter(self) -> bool: + training_step_fx = getattr(self.trainer.lightning_module, "training_step") + contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True) + return contains_dataloader_iter + + def _select_data_fetcher(self) -> AbstractDataFetcher: + if self.trainer.sanity_checking: + return DataFetcher() + + if self.trainer.training and self._check_training_step_requires_dataloader_iter(): + rank_zero_warn( + "Found `dataloader_iter` argument in the `training_step`. Note that the support for " + "this signature is experimental and the behavior is subject to change." 
+ ) + return DataLoaderIterDataFetcher() + elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": + # note: this is an experimental feature + if not self.trainer.training_type_plugin.on_gpu: + raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") + return InterBatchParallelDataFetcher() + + return DataFetcher() + + def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable: + stage: str = self.trainer.state.stage.value + data_fetcher = setattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher() + data_fetcher.setup( + dataloader, + stage=stage, + batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx), + profiler=self.trainer.profiler, + ) + setattr(self, f"{stage}_data_fetcher", data_fetcher) + if isinstance(data_fetcher, DataLoaderIterDataFetcher): + return data_fetcher + return enumerate(data_fetcher) def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 59442730245e0..72f54a891cde3 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -47,15 +47,19 @@ class SimpleDataFetcher(AbstractDataFetcher): def fetching_function(self): while True: try: - yield next(self.dataloader_iter), False + return next(self.dataloader_iter), False except StopIteration: return None, True """ @abstractmethod - def fetching_function(self) -> Generator: + def fetching_function(self) -> Any: """Override with your own fetching logic.""" + @abstractmethod + def prefetching(self, prefetch_batches: int) -> None: + """Override with your own pre-fetching logic.""" + def __init__( self, prefetch_batches: int = 0, @@ -63,6 +67,7 @@ def __init__( if prefetch_batches < 0: raise MisconfigurationException("`prefetch_batches` should at least be 0.") + self.store_on_device = False self.prefetch_batches = prefetch_batches + 1 self.dataloader: Optional[Iterable] = None @@ -192,6 +197,10 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]: self.reset() self.dataloader_iter = iter(self.dataloader) self._apply_patch() + self.prefetching(self.prefetch_batches) + return self + + def __next__(self): return self.fetching_function() def reset(self) -> None: @@ -241,34 +250,38 @@ def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> No def wait(self) -> None: """Hook to override to indicate the `DataFetcher` to wait for an event.""" - def fetching_function(self) -> Generator: - self.done = False - while not self.done: - self._prefetching(self.prefetch_batches) - - while self.batches: - try: - yield_batch = self.pop_batch() - self._fetch_next_batch() - - # wait for batch to be available. 
- self.wait() - - # yield last and has next - yield (self.move_data_to_device(yield_batch) if not self.store_on_device else yield_batch, False) - except StopIteration: - self.batches.insert(0, yield_batch) - break - - yield from self._consume_prefetched_batches() - - def _prefetching(self, prefetch_batches: int) -> None: + def prefetching(self, prefetch_batches: int) -> None: for _ in range(prefetch_batches): try: self._fetch_next_batch() except StopIteration: break + def fetching_function(self) -> Optional[Tuple[Any, bool]]: + if self.done: + while self.batches: + return self._get_queued_batch() + raise StopIteration + else: + try: + yield_batch = self.pop_batch() + self._fetch_next_batch() + + # wait for batch to be available. + self.wait() + + # yield last and has next + return yield_batch, False + # FIXME: Why does this count as a python `referrers` ? + # return (self.move_data_to_device(yield_batch) if not self.store_on_device else yield_batch, False) + except StopIteration: + self.batches.insert(0, yield_batch) + self.done = True + return self._get_queued_batch() + + except IndexError: + raise StopIteration + @contextmanager def apply_profiler(self, name: str) -> Generator: if self.profiler: @@ -291,13 +304,13 @@ def _consume_prefetched_batches(self) -> Generator: while self.batches: yield from self._yield_batch() - def _yield_batch(self) -> Generator: + def _get_queued_batch(self) -> Tuple[Any, bool]: self.wait() batch = self.batches.pop(0) if not self.store_on_device: batch = self.move_data_to_device(batch) is_last = len(self.batches) == 0 - yield batch, is_last + return batch, is_last def move_data_to_device(self, batch: Any) -> Any: if self.batch_to_device: @@ -406,7 +419,15 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: ... """ - def fetching_function(self) -> Generator: - iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self)) + def __init__(self): + super().__init__() + # prevent calling ``move_batch_to_device``` + self.store_on_device = True + + def prefetching(self, prefetch_batches: int) -> None: + self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self)) + + def fetching_function(self): while not self.done: - yield iterator, self.fetched, self.done + return self.fetched, (self.iterator, self.done) + raise StopIteration diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index da9a4acabf8f8..b351165e03fd8 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os from time import time from typing import Any +from unittest import mock import pytest import torch @@ -174,15 +176,22 @@ def test_dataloader(self): def test_trainer_num_prefetch_batches(tmpdir): model = RecommenderModel() - trainer_kwargs = dict(default_root_dir=tmpdir, max_epochs=1, gpus=1, limit_train_batches=3, limit_val_batches=0) - - t0 = time() - trainer = Trainer(**trainer_kwargs) - trainer.data_connector.data_fetcher = InterBatchParallelDataFetcher() - trainer.fit(model) - t1 = time() - global_step = trainer.global_step - assert isinstance(trainer.data_connector.data_fetcher, InterBatchParallelDataFetcher) + trainer_kwargs = dict( + default_root_dir=tmpdir, + max_epochs=1, + gpus=1, + limit_train_batches=4, + limit_val_batches=0, + num_sanity_val_steps=0, + ) + + with mock.patch.dict(os.environ, {"PL_INTER_BATCH_PARALLELISM": "1"}): + t0 = time() + trainer = Trainer(**trainer_kwargs) + trainer.fit(model) + t1 = time() + global_step = trainer.global_step + assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher) torch.cuda.synchronize() @@ -191,8 +200,8 @@ def test_trainer_num_prefetch_batches(tmpdir): trainer.fit(model) t3 = time() - assert global_step == trainer.global_step == 3 - assert isinstance(trainer.data_connector.data_fetcher, DataFetcher) + assert global_step == trainer.global_step == 4 + assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) ratio = (t3 - t2) / (t1 - t0) assert ratio > 1.1, ratio
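
Illustrative sketch (not part of the patch): the new `_select_data_fetcher` picks `DataLoaderIterDataFetcher` when `training_step` declares an explicit `dataloader_iter` parameter, `InterBatchParallelDataFetcher` when `PL_INTER_BATCH_PARALLELISM=1` is set for a GPU run, and the plain `DataFetcher` otherwise. The module below shows the first path only; `RandomDataset` and `IteratorModel` are hypothetical names used for illustration, and the signature follows the `training_step(self, dataloader_iter, batch_idx)` form referenced in the `fetching.py` docstring context above. It assumes the PyTorch Lightning API as it stands at the time of this change.

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.randn(4)


class IteratorModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, dataloader_iter, batch_idx):
        # Declaring `dataloader_iter` explicitly should make the data connector
        # select `DataLoaderIterDataFetcher`, handing the iterator to the model
        # so it decides when to pull the next batch.
        batch = next(dataloader_iter)
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=2)


# Hypothetical usage:
# trainer = pl.Trainer(max_epochs=1, limit_train_batches=2)
# trainer.fit(IteratorModel())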