Commit 92c7eec

Showing 10 changed files with 160 additions and 70 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `DataLoaderIterDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020))


- Added `DataFetcher` within `Fit / Evaluation` Loop ([#9047](https://github.com/PyTorchLightning/pytorch-lightning/pull/9047))


- Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005))


2 changes: 2 additions & 0 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -280,6 +280,8 @@ def _training_step(
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()

del step_kwargs

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

_check_training_step_output(self.trainer.lightning_module, training_step_output)
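The added `del step_kwargs` drops the loop's last reference to the step arguments right after the accelerator has consumed them, so any tensors bound inside that dict can be released before `training_step_end` and the output checks run. A minimal sketch of the effect, using `sys.getrefcount` and a stand-in `run_step` function (not Lightning API):

```python
import sys

import torch


def run_step(step_kwargs: dict) -> torch.Tensor:
    # stand-in for `self.trainer.accelerator.training_step(step_kwargs)`
    return step_kwargs["batch"].sum()


batch = torch.randn(1024, 1024)
step_kwargs = {"batch": batch}
out = run_step(step_kwargs)

print(sys.getrefcount(batch))  # the dict still holds a reference to the batch
del step_kwargs                # mirror the commit: drop the dict right after the step
print(sys.getrefcount(batch))  # one reference fewer; the tensor can be freed sooner
```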
12 changes: 7 additions & 5 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Sequence, Union
from typing import Any, Iterator, List, Optional, Sequence, Union

from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.fetching import DataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -98,10 +97,13 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
def advance(self, *args: Any, **kwargs: Any) -> None:
"""Performs evaluation on one single dataloader"""
void(*args, **kwargs)

dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
data_fetcher = DataFetcher()
data_fetcher.setup(dataloader)
dataloader_iter = enumerate(data_fetcher)
dataloader = self.trainer.data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=self.current_dataloader_idx
)
dataloader_iter = iter(dataloader)

dl_max_batches = self._max_batches[self.current_dataloader_idx]

dl_outputs = self.epoch_loop.run(
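For context, `get_profiled_dataloader` (added to the data connector further down in this diff) returns `enumerate(data_fetcher)` for the regular `DataFetcher` case, so the `iter(dataloader)` call above only makes the iterator explicit; each `next()` then yields `(batch_idx, (batch, is_last))`. A tiny illustration with a plain list standing in for a fetcher that produces `(batch, is_last)` pairs:

```python
# a plain list stands in for a fetcher yielding (batch, is_last) pairs
fetcher_output = [("batch_0", False), ("batch_1", True)]

dataloader = enumerate(fetcher_output)  # shape returned by get_profiled_dataloader
dataloader_iter = iter(dataloader)      # enumerate objects are already iterators

print(next(dataloader_iter))  # (0, ('batch_0', False))
print(next(dataloader_iter))  # (1, ('batch_1', True))
```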
5 changes: 3 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -91,8 +91,9 @@ def advance(
if batch is None:
raise StopIteration

with self.trainer.profiler.profile("evaluation_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
if not self.trainer.data_connector.evaluation_data_fetcher.store_on_device:
with self.trainer.profiler.profile("evaluation_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

self.batch_progress.increment_ready()

9 changes: 4 additions & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -132,14 +132,13 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
else:
_, (batch, is_last) = next(dataloader_iter)

# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
# FIXME: Remove with InterBatchProcessor.
if not self.trainer.data_connector.data_fetcher.store_on_device:
if not self.trainer.data_connector.train_data_fetcher.store_on_device:
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch)

# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
self.batch_progress.increment_ready()

with self.trainer.profiler.profile("run_training_batch"):
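The two epoch-loop hunks above apply the same guard: when the active fetcher reports `store_on_device=True` (meaning it already placed the batch on the device), the profiled `batch_to_device` call is skipped. A minimal sketch of that guard as a stand-in helper, not the actual loop code:

```python
from typing import Any

import torch


def maybe_move_to_device(batch: Any, device: torch.device, store_on_device: bool) -> Any:
    # if the fetcher already placed the batch on the device, skip the extra transfer
    if store_on_device:
        return batch
    return batch.to(device) if isinstance(batch, torch.Tensor) else batch


batch = torch.ones(2, 2)
batch = maybe_move_to_device(batch, torch.device("cpu"), store_on_device=False)
```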
9 changes: 5 additions & 4 deletions pytorch_lightning/loops/fit_loop.py
@@ -14,7 +14,7 @@

import logging
from contextlib import suppress
from typing import Optional
from typing import Iterator, Optional

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
@@ -192,12 +192,13 @@ def on_advance_start(self) -> None:

def advance(self) -> None:
"""Runs one whole epoch."""
train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader)
dataloader_iter = iter(dataloader)

with self.trainer.profiler.profile("run_training_epoch"):
# run train epoch
epoch_output = self.epoch_loop.run(train_dataloader)
epoch_output = self.epoch_loop.run(dataloader_iter)

if epoch_output is None:
return
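`FitLoop.advance` now builds the iterator itself and hands `dataloader_iter` to the epoch loop, instead of passing the profiled dataloader object. A sketch of that handoff under the `(batch_idx, (batch, is_last))` convention shown above, with a hypothetical `run_epoch` standing in for `TrainingEpochLoop.run`:

```python
from typing import Any, Iterator, List


def run_epoch(dataloader_iter: Iterator) -> List[Any]:
    # the epoch loop pulls batches with next() from an iterator it does not own
    outputs = []
    while True:
        try:
            batch_idx, (batch, is_last) = next(dataloader_iter)
        except StopIteration:
            break
        outputs.append(batch)
        if is_last:
            break
    return outputs


# the caller (the fit loop) creates and owns the iterator
dataloader_iter = iter(enumerate([(b, b == 2) for b in range(3)]))
print(run_epoch(dataloader_iter))  # [0, 1, 2]
```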
@@ -113,7 +113,7 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]:
Args:
dataloader_iter: the iterator over the dataloader producing the new batch
"""
_, (dataloader_iter, batch_idx, is_last) = next(dataloader_iter)
batch_idx, (dataloader_iter, is_last) = next(dataloader_iter)

self.trainer.logger_connector.on_batch_start()
response = self.trainer.call_hook("on_batch_start")
78 changes: 65 additions & 13 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -11,22 +11,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial
from typing import Callable, Iterable, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, InterBatchParallelDataFetcher
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
DataFetcher,
DataLoaderIterDataFetcher,
InterBatchParallelDataFetcher,
)
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import rank_zero_warn


class DataConnector:
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
def __init__(
self,
trainer: "pl.Trainer",
multiple_trainloader_mode: str = "max_size_cycle",
train_data_fetcher: Optional[AbstractDataFetcher] = None,
validate_data_fetcher: Optional[AbstractDataFetcher] = None,
test_data_fetcher: Optional[AbstractDataFetcher] = None,
):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode
self.data_fetcher: AbstractDataFetcher = DataFetcher()

self.train_data_fetcher = train_data_fetcher
self.validate_data_fetcher = validate_data_fetcher
self.test_data_fetcher = test_data_fetcher
self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None

@property
def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]:
if self.trainer.sanity_checking:
return self.sanity_check_data_fetcher
return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher

def on_trainer_init(
self,
@@ -66,15 +91,42 @@ def on_trainer_init(
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False

def get_profiled_train_dataloader(self, train_dataloader) -> Iterable:
# FIXME: Temporary hack
if isinstance(self.data_fetcher, InterBatchParallelDataFetcher):
self.data_fetcher.setup(train_dataloader, batch_to_device=self.trainer.accelerator.batch_to_device)
else:
self.data_fetcher.setup(train_dataloader)
prefetcher_iter = iter(self.data_fetcher)
profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch")
return profiled_dl
def _check_training_step_requires_dataloader_iter(self) -> bool:
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True)
return contains_dataloader_iter

def _select_data_fetcher(self) -> AbstractDataFetcher:
if self.trainer.sanity_checking:
return DataFetcher()

if self.trainer.training and self._check_training_step_requires_dataloader_iter():
rank_zero_warn(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return DataLoaderIterDataFetcher()
elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
# note: this is an experimental feature
if not self.trainer.training_type_plugin.on_gpu:
raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
return InterBatchParallelDataFetcher()

return DataFetcher()

def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable:
stage: str = self.trainer.state.stage.value
data_fetcher = setattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher()
data_fetcher.setup(
dataloader,
stage=stage,
batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx),
profiler=self.trainer.profiler,
)
setattr(self, f"{stage}_data_fetcher", data_fetcher)
if isinstance(data_fetcher, DataLoaderIterDataFetcher):
return data_fetcher
return enumerate(data_fetcher)

def prepare_data(self) -> None:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
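The new `_select_data_fetcher` picks a fetcher per stage: `DataLoaderIterDataFetcher` when `training_step` explicitly takes a `dataloader_iter` argument, the experimental `InterBatchParallelDataFetcher` when `PL_INTER_BATCH_PARALLELISM=1` (GPU only), and a plain `DataFetcher` otherwise. A self-contained sketch of those two checks using `inspect` directly rather than Lightning's `is_param_in_hook_signature` helper; the function names here are stand-ins, not Lightning API:

```python
import inspect
import os
from typing import Callable


def requires_dataloader_iter(training_step: Callable) -> bool:
    # explicit-parameter check, analogous to is_param_in_hook_signature(..., explicit=True)
    return "dataloader_iter" in inspect.signature(training_step).parameters


def select_fetcher_name(training_step: Callable, training: bool = True) -> str:
    if training and requires_dataloader_iter(training_step):
        return "DataLoaderIterDataFetcher"  # experimental signature support
    if training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
        return "InterBatchParallelDataFetcher"  # experimental, requires GPUs
    return "DataFetcher"


def training_step(self, dataloader_iter, batch_idx):  # experimental signature
    ...


print(select_fetcher_name(training_step))  # DataLoaderIterDataFetcher
```

Note also that in `get_profiled_dataloader` the expression `setattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher()` always falls through to `_select_data_fetcher()`, since `setattr` returns `None`; the `setattr` call only resets the per-stage attribute before the freshly built fetcher is stored a few lines later.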
79 changes: 50 additions & 29 deletions pytorch_lightning/utilities/fetching.py
@@ -47,22 +47,27 @@ class SimpleDataFetcher(AbstractDataFetcher):
def fetching_function(self):
while True:
try:
yield next(self.dataloader_iter), False
return next(self.dataloader_iter), False
except StopIteration:
return None, True
"""

@abstractmethod
def fetching_function(self) -> Generator:
def fetching_function(self) -> Any:
"""Override with your own fetching logic."""

@abstractmethod
def prefetching(self, prefetch_batches: int) -> None:
"""Override with your own pre-fetching logic."""

def __init__(
self,
prefetch_batches: int = 0,
) -> None:
if prefetch_batches < 0:
raise MisconfigurationException("`prefetch_batches` should at least be 0.")

self.store_on_device = False
self.prefetch_batches = prefetch_batches + 1

self.dataloader: Optional[Iterable] = None
@@ -192,6 +197,10 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
self.reset()
self.dataloader_iter = iter(self.dataloader)
self._apply_patch()
self.prefetching(self.prefetch_batches)
return self

def __next__(self):
return self.fetching_function()

def reset(self) -> None:
@@ -241,34 +250,38 @@ def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> No
def wait(self) -> None:
"""Hook to override to indicate the `DataFetcher` to wait for an event."""

def fetching_function(self) -> Generator:
self.done = False
while not self.done:
self._prefetching(self.prefetch_batches)

while self.batches:
try:
yield_batch = self.pop_batch()
self._fetch_next_batch()

# wait for batch to be available.
self.wait()

# yield last and has next
yield (self.move_data_to_device(yield_batch) if not self.store_on_device else yield_batch, False)
except StopIteration:
self.batches.insert(0, yield_batch)
break

yield from self._consume_prefetched_batches()

def _prefetching(self, prefetch_batches: int) -> None:
def prefetching(self, prefetch_batches: int) -> None:
for _ in range(prefetch_batches):
try:
self._fetch_next_batch()
except StopIteration:
break

def fetching_function(self) -> Optional[Tuple[Any, bool]]:
if self.done:
while self.batches:
return self._get_queued_batch()
raise StopIteration
else:
try:
yield_batch = self.pop_batch()
self._fetch_next_batch()

# wait for batch to be available.
self.wait()

# yield last and has next
return yield_batch, False
# FIXME: Why does this count as a python `referrers` ?
# return (self.move_data_to_device(yield_batch) if not self.store_on_device else yield_batch, False)
except StopIteration:
self.batches.insert(0, yield_batch)
self.done = True
return self._get_queued_batch()

except IndexError:
raise StopIteration

@contextmanager
def apply_profiler(self, name: str) -> Generator:
if self.profiler:
@@ -291,13 +304,13 @@ def _consume_prefetched_batches(self) -> Generator:
while self.batches:
yield from self._yield_batch()

def _yield_batch(self) -> Generator:
def _get_queued_batch(self) -> Tuple[Any, bool]:
self.wait()
batch = self.batches.pop(0)
if not self.store_on_device:
batch = self.move_data_to_device(batch)
is_last = len(self.batches) == 0
yield batch, is_last
return batch, is_last

def move_data_to_device(self, batch: Any) -> Any:
if self.batch_to_device:
@@ -406,7 +419,15 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
...
"""

def fetching_function(self) -> Generator:
iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))
def __init__(self):
super().__init__()
# prevent calling ``move_batch_to_device```
self.store_on_device = True

def prefetching(self, prefetch_batches: int) -> None:
self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))

def fetching_function(self):
while not self.done:
yield iterator, self.fetched, self.done
return self.fetched, (self.iterator, self.done)
raise StopIteration
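The refactor above moves the fetchers from generator-based `fetching_function`s to an `__iter__`/`__next__` protocol: `__iter__` now calls `prefetching` to fill a small buffer, and `__next__` returns a `(batch, is_last)` pair or raises `StopIteration`. A condensed, self-contained sketch of that pattern (a toy class, not the actual `DataFetcher` implementation), keeping the commit's `prefetch_batches + 1` lookahead:

```python
from collections import deque
from typing import Any, Deque, Iterable, Iterator, Tuple


class PrefetchingFetcher:
    """Toy fetcher: prefetch a few batches, then hand out (batch, is_last) pairs."""

    def __init__(self, prefetch_batches: int = 0) -> None:
        if prefetch_batches < 0:
            raise ValueError("`prefetch_batches` should at least be 0.")
        # like the commit, always keep at least one batch of lookahead
        self.prefetch_batches = prefetch_batches + 1
        self.batches: Deque[Any] = deque()
        self.done = False

    def setup(self, dataloader: Iterable) -> None:
        self.dataloader = dataloader

    def __iter__(self) -> "PrefetchingFetcher":
        self.dataloader_iter: Iterator = iter(self.dataloader)
        self.batches.clear()
        self.done = False
        self.prefetching(self.prefetch_batches)
        return self

    def prefetching(self, prefetch_batches: int) -> None:
        for _ in range(prefetch_batches):
            try:
                self.batches.append(next(self.dataloader_iter))
            except StopIteration:
                self.done = True
                break

    def __next__(self) -> Tuple[Any, bool]:
        if not self.batches:
            raise StopIteration
        batch = self.batches.popleft()
        if not self.done:
            try:
                self.batches.append(next(self.dataloader_iter))
            except StopIteration:
                self.done = True
        # is_last once the source is exhausted and the buffer is empty
        return batch, self.done and not self.batches


fetcher = PrefetchingFetcher(prefetch_batches=1)
fetcher.setup(range(3))
for batch_idx, (batch, is_last) in enumerate(fetcher):
    print(batch_idx, batch, is_last)  # ends with: 2 2 True
```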