
Refactor CombinedLoader using pytrees (#16714)
carmocca authored Feb 11, 2023
1 parent 984f49f commit d660379
Showing 12 changed files with 271 additions and 448 deletions.
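The refactor named in the commit title replaces Lightning's bespoke `CycleIterator`/`CombinedLoaderIterator` machinery with pytree operations: an arbitrarily nested collection of dataloaders is flattened to a plain list, handled uniformly, and batches are unflattened back into the user's original structure. The sketch below only illustrates that idea and is not code from this commit; it uses PyTorch's private `torch.utils._pytree` helpers, whose API is not a public contract.

```python
# Illustration of the pytree round trip (not Lightning's implementation).
from torch.utils import _pytree as pytree
from torch.utils.data import DataLoader

loaders = {"images": DataLoader(range(4), batch_size=2), "text": DataLoader(range(2))}

# Flatten the nested structure into a list of leaves plus a spec that
# remembers how to rebuild it.
flat_loaders, spec = pytree.tree_flatten(loaders)

# Operate on the flat list (create iterators, wrap samplers, ...).
iterators = [iter(dl) for dl in flat_loaders]
flat_batch = [next(it) for it in iterators]

# Restore the user's structure for the combined batch.
batch = pytree.tree_unflatten(flat_batch, spec)
# -> {"images": tensor([0, 1]), "text": tensor([0])}
```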
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -77,6 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Renamed `TQDMProgressBar.main_progress_bar` to `TQDMProgressBar.train_progress_bar` ([#16695](https://github.com/Lightning-AI/lightning/pull/16695))

- Marked `lightning.pytorch.utilities.supporters.CombinedDataset` as protected ([#16714](https://github.com/Lightning-AI/lightning/pull/16714))

- Disabled strict loading in multiprocessing launcher ("ddp_spawn", etc.) when loading weights back into the main process ([#16365](https://github.com/Lightning-AI/lightning/pull/16365))

@@ -266,6 +267,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the unused `lightning.pytorch.utilities.metrics.metrics_to_scalars` function ([#16681](https://github.com/Lightning-AI/lightning/pull/16681))

- Removed the unused `lightning.pytorch.utilities.supporters.{SharedCycleIteratorState,CombinedLoaderIterator}` classes ([#16714](https://github.com/Lightning-AI/lightning/pull/16714))

### Fixed

-
12 changes: 2 additions & 10 deletions src/lightning/pytorch/loops/fetchers.py
@@ -17,7 +17,7 @@
from torch.utils.data.dataloader import DataLoader

from lightning.fabric.utilities.data import has_len
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.trainer.supporters import _shutdown_workers_and_reset_iterator, CombinedLoader
from lightning.pytorch.utilities.exceptions import MisconfigurationException


@@ -45,14 +45,6 @@ def dataloader(self) -> Iterable:
            )
        return self._dataloader

-    @property
-    def loader_iters(self) -> Any:
-        if self.dataloader_iter is None:
-            raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.")
-        if isinstance(self.dataloader, CombinedLoader):
-            return self.dataloader_iter.loader_iters
-        return self.dataloader_iter
-
    def __iter__(self) -> "_DataFetcher":
        self.reset()
        self.dataloader_iter = iter(self.dataloader)
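Dropping the `loader_iters` property means the fetcher never reaches into per-loader iterators: the pytree-based `CombinedLoader` yields batches that already mirror the input structure. A hedged usage sketch follows; the positional `mode` argument and its `"max_size_cycle"` value are taken from this diff, not verified against the final public API.

```python
# Hedged sketch: iterate a CombinedLoader directly; batches arrive combined,
# so no per-loader iterators need to be exposed.
from torch.utils.data import DataLoader
from lightning.pytorch.trainer.supporters import CombinedLoader

iterables = {"a": DataLoader(range(6), batch_size=2), "b": DataLoader(range(2))}
combined = CombinedLoader(iterables, "max_size_cycle")

for batch in combined:
    # `batch` is a dict with the same keys as `iterables`.
    ...
```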
@@ -80,7 +72,7 @@ def teardown(self) -> None:
        if isinstance(self._dataloader, CombinedLoader):
            self._dataloader.reset()
        if isinstance(self._dataloader, DataLoader):
-            CombinedLoader._shutdown_workers_and_reset_iterator(self._dataloader)
+            _shutdown_workers_and_reset_iterator(self._dataloader)
        self.dataloader_iter = None
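`_shutdown_workers_and_reset_iterator` is now imported as a module-level helper instead of being reached through `CombinedLoader`. Its body is not part of this diff; a plausible sketch of such a helper, built on PyTorch's `DataLoader` internals, is shown below as an assumption only.

```python
# Plausible sketch only -- the real helper lives in
# lightning/pytorch/trainer/supporters.py and is not shown in this diff.
from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter


def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
    # `_iterator` caches the multiprocessing iterator (e.g. with
    # `persistent_workers=True`); shutting it down releases the workers.
    iterator = getattr(dataloader, "_iterator", None)
    if isinstance(iterator, _MultiProcessingDataLoaderIter):
        iterator._shutdown_workers()
    dataloader._iterator = None
```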


25 changes: 5 additions & 20 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -28,7 +28,7 @@
from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
from lightning.pytorch.strategies import DDPSpawnStrategy
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
-from lightning.pytorch.trainer.supporters import CombinedLoader, CycleIterator
+from lightning.pytorch.trainer.supporters import _LITERAL_SUPPORTED_MODES, CombinedLoader
from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader, has_len_all_ranks
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
@@ -40,7 +40,7 @@


class DataConnector:
-    def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
+    def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: _LITERAL_SUPPORTED_MODES = "max_size_cycle"):
        self.trainer = trainer
        self.multiple_trainloader_mode = multiple_trainloader_mode
        self._train_dataloader_source = _DataLoaderSource(None, "")
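`multiple_trainloader_mode` is now typed with `_LITERAL_SUPPORTED_MODES` instead of a bare `str`. The alias is defined in `lightning/pytorch/trainer/supporters.py` and is not shown here; presumably it is a `typing.Literal` over the supported mode names, roughly as sketched below (every member other than the `"max_size_cycle"` default is an assumption).

```python
# Assumed shape of the alias; the actual members are not shown in this diff.
from typing import Literal

_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle"]
```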
@@ -239,28 +239,17 @@ def _prepare_dataloader(
"""This function handles the following functionalities:
- Injecting a `DistributedDataSamplerWrapper` into the `DataLoader` if on a distributed environment
- Wrapping the datasets and samplers into fault-tolerant components
- Wrapping the dataloader based on strategy-specific logic
"""
if isinstance(dataloader, CombinedLoader):
# apply `_prepare_dataloader` on all the collection of loaders
dataloader.loaders = apply_to_collection(
dataloader.loaders, (DataLoader, CycleIterator), self._prepare_dataloader, shuffle, mode=mode
)
# the length need to recomputed across all dataloaders in case of special behavior.
dataloader._apply_cycle_iterator_length()
for i, dl in enumerate(dataloader._loaders_flattened):
dataloader._update_index(self._prepare_dataloader(dl, shuffle=shuffle, mode=mode), i)
return dataloader

# don't do anything if it's not a dataloader
if not isinstance(dataloader, (DataLoader, CycleIterator)):
if not isinstance(dataloader, DataLoader):
return dataloader

cycle_iterator: Optional[CycleIterator] = None

if isinstance(dataloader, CycleIterator):
cycle_iterator = dataloader
dataloader = dataloader.loader

if (
self._requires_distributed_sampler(dataloader) # sets the distributed sampler
or mode == RunningStage.PREDICTING # to track indices for the predictions
@@ -277,10 +266,6 @@

        dataloader = self.trainer.strategy.process_dataloader(dataloader)

-        if cycle_iterator is not None:
-            cycle_iterator.loader = dataloader
-            return cycle_iterator
-
        return dataloader

    def _resolve_sampler(
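In `_prepare_dataloader`, recursing over `CycleIterator`-wrapped loaders with `apply_to_collection` is replaced by a loop over the flat view the pytree-based `CombinedLoader` exposes (`_loaders_flattened`), writing each prepared loader back by position with `_update_index`. The sketch below shows how such a pair of hooks could be backed by a stored tree spec; the attribute and method names come from the diff, but the implementation is an assumption, not this PR's code.

```python
# Hedged sketch of a flattened-loaders container backed by a pytree spec.
from typing import Any, List

from torch.utils import _pytree as pytree


class _FlattenedLoadersSketch:
    def __init__(self, iterables: Any) -> None:
        # Keep the flat leaves plus the spec needed to rebuild the tree.
        self._flattened, self._spec = pytree.tree_flatten(iterables)

    @property
    def _loaders_flattened(self) -> List[Any]:
        return self._flattened

    def _update_index(self, new_loader: Any, index: int) -> None:
        # Replace one leaf; the nested structure is rebuilt on demand.
        self._flattened[index] = new_loader

    @property
    def iterables(self) -> Any:
        return pytree.tree_unflatten(self._flattened, self._spec)
```

With hooks like these, the connector loop in the diff reduces to preparing each flat loader and slotting it back by index: `for i, dl in enumerate(loader._loaders_flattened): loader._update_index(self._prepare_dataloader(dl, ...), i)`.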