[Feat] Inter-batch-parallelism #8700

Closed
wants to merge 28 commits into from
Changes from all commits (28 commits)
241ad1f
update
tchaton Aug 3, 2021
cfe0c08
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2021
132967d
update
tchaton Aug 3, 2021
45690e3
Merge branch 'inter_batch_parallism' of https://github.com/PyTorchLig…
tchaton Aug 3, 2021
aa1709e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2021
3cc770b
update
tchaton Aug 3, 2021
28c45ce
Merge branch 'inter_batch_parallism' of https://github.com/PyTorchLig…
tchaton Aug 3, 2021
67f034a
bad merge
tchaton Aug 3, 2021
989b7cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2021
7e08c87
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2021
ef34d31
improve test
tchaton Aug 3, 2021
1e6c977
Merge branch 'inter_batch_parallism' of https://github.com/PyTorchLig…
tchaton Aug 3, 2021
08f7860
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2021
f0a89d3
Raise exception for non GPUs
kaushikb11 Aug 4, 2021
46bdb43
Code refactor
kaushikb11 Aug 4, 2021
c14aad5
Update test & dataloader fetcher
kaushikb11 Aug 4, 2021
8a4da32
Fix profiled iterator
kaushikb11 Aug 5, 2021
1c795f7
Attempt to fix the issue with hacky iterator
Aug 6, 2021
5a372e2
Update LightningFetcher & use num_prefetch_batches instead
kaushikb11 Aug 9, 2021
417e30a
Merge branch 'inter_batch_parallism' of https://github.com/PyTorchLig…
kaushikb11 Aug 9, 2021
a593b54
Update event usage
kaushikb11 Aug 9, 2021
94eec7f
Update tests
kaushikb11 Aug 9, 2021
b101cb0
Update defaults
kaushikb11 Aug 9, 2021
c3578a5
code health fix
kaushikb11 Aug 9, 2021
e0c82b6
Merge branch 'master' into inter_batch_parallism
kaushikb11 Aug 9, 2021
ee96e37
Update last for LightningFetcher
kaushikb11 Aug 10, 2021
4eca409
Merge branch 'inter_batch_parallism' of https://github.com/PyTorchLig…
kaushikb11 Aug 10, 2021
53eb4d7
Update LightningFetcher
kaushikb11 Aug 10, 2021
5 changes: 1 addition & 4 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -107,15 +107,12 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
Raises:
StopIteration: When the epoch is canceled by the user returning -1
"""
_, (batch, is_last) = next(dataloader_iter)
_, batch, is_last = next(dataloader_iter)
self.is_last_batch = is_last

# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx)

self.batch_progress.increment_ready()

with self.trainer.profiler.profile("run_training_batch"):
23 changes: 16 additions & 7 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -12,20 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union
from typing import Iterator, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import prefetch_iterator
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.dataloader_fetcher import LightningFetcher
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS


class DataConnector:
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
def __init__(
self,
trainer: "pl.Trainer",
multiple_trainloader_mode: str = "max_size_cycle",
num_prefetch_batches: int = 1,
):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode
self.num_prefetch_batches = num_prefetch_batches

def on_trainer_init(
self,
@@ -60,11 +66,14 @@ def on_trainer_init(
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False

def get_profiled_train_dataloader(self, train_dataloader):
profiled_dl = self.trainer.profiler.profile_iterable(
enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
def get_profiled_train_dataloader(self, train_dataloader) -> Iterator:
fetcher = LightningFetcher(
train_dataloader,
self.trainer.accelerator.batch_to_device,
self.trainer.profiler,
self.num_prefetch_batches,
)
return profiled_dl
return iter(fetcher)

def prepare_data(self) -> None:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
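Putting this connector change together with the training_epoch_loop hunk above, the training loop now consumes 3-tuples whose batch element is already on the device. A minimal sketch of that contract (names taken from the diff, loop body elided):

# dataloader_iter is what get_profiled_train_dataloader() now returns
# each item is (counter, batch, is_last); batch_to_device has already been
# applied inside the LightningFetcher, so advance() no longer does it
counter, batch, is_last = next(dataloader_iter)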
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/supporters.py
@@ -14,8 +14,9 @@

import os
from collections.abc import Iterable, Iterator, Mapping, Sequence
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
from typing import Any, Callable, ContextManager, Dict, Generator, Optional, Tuple, Union

import torch
from torch import Tensor
19 changes: 17 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -24,7 +24,7 @@
import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, IPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.loggers import LightningLoggerBase
@@ -159,6 +159,7 @@ def __init__(
move_metrics_to_cpu: bool = False,
multiple_trainloader_mode: str = "max_size_cycle",
stochastic_weight_avg: bool = False,
num_prefetch_batches: int = 1,
):
r"""
Customize every aspect of training via flags
@@ -334,6 +335,10 @@ def __init__(
stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
<https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>`_

num_prefetch_batches: By default, input batches are processed sequentially. Setting ``num_prefetch_batches``
to a value ``n`` greater than 1 hides the host-to-device transfer latency of the next ``n`` batches
behind the model's forward and backward calls. Values greater than 1 are only supported on CUDA devices.

"""
super().__init__()
Trainer._log_api_event("init")
@@ -344,7 +349,7 @@ def __init__(
# init connectors
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)
self.data_connector = DataConnector(self, multiple_trainloader_mode)
self.data_connector = DataConnector(self, multiple_trainloader_mode, num_prefetch_batches)
self.optimizer_connector = OptimizerConnector(self)

self.accelerator_connector = AcceleratorConnector(
@@ -366,6 +371,9 @@
amp_level,
plugins,
)

self._validate_num_prefetch_batches(num_prefetch_batches)

self.logger_connector = LoggerConnector(self, log_gpu_memory)
self.model_connector = ModelConnector(self)
self.callback_connector = CallbackConnector(self)
@@ -1333,6 +1341,13 @@ def _log_device_info(self) -> None:
" `Trainer(ipus=8)` or script `--ipus=8`."
)

def _validate_num_prefetch_batches(self, num_prefetch_batches: int) -> None:
if num_prefetch_batches > 1 and not isinstance(self.accelerator, GPUAccelerator):
raise MisconfigurationException(
f"You have passed `Trainer(num_prefetch_batches={num_prefetch_batches})`"
" but it is only supported on GPUs"
)

def _on_exception(self):
if not _fault_tolerant_enabled():
return
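As a usage illustration of the new flag (a minimal sketch, not part of this diff; it assumes a CUDA-capable machine and an already-defined LightningModule instance named model):

import pytorch_lightning as pl

# prefetch the next 2 batches so that their host-to-device copies overlap with the
# forward/backward of the current batch; values > 1 require a GPU accelerator,
# otherwise _validate_num_prefetch_batches raises a MisconfigurationException
trainer = pl.Trainer(gpus=1, max_epochs=1, num_prefetch_batches=2)
trainer.fit(model)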
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -16,6 +16,7 @@
import numpy

from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401
from pytorch_lightning.utilities.dataloader_fetcher import LightningFetcher
from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only # noqa: F401
from pytorch_lightning.utilities.enums import ( # noqa: F401
AMPType,
134 changes: 134 additions & 0 deletions pytorch_lightning/utilities/dataloader_fetcher.py
@@ -0,0 +1,134 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Generator, List, Tuple

import torch

import pytorch_lightning as pl


def profiled_iterator(iterable, profiler):
iterator = iter(iterable)
while True:
try:
with profiler.profile("get_train_batch"):
yield next(iterator)
except StopIteration:
return


class LightningFetcher:
Contributor:
n00b question: certain classes are prefixed with "Lightning" while others aren't. Components by name:
  • with the prefix: LightningModule, LightningDataModule, LightningLoggerBase
  • without: Trainer, Callbacks, Logger, Profiler
Utilities like this one don't contain logic specific to the rest of the framework, so I wonder if we could call this just Prefetcher?
It would also help users feel that they don't need to learn Lightning-specific names.

Contributor Author:
I would propose DataFetcher.

"""
This class is used to perform ``pre-fetching`` for the ``train`` dataloader
and apply inter-batch parallelism if enabled.
Comment on lines +34 to +35

Contributor:
This could/should be used for evaluation too, to still overlap the forward with the host-to-device transfer.

Contributor Author:
Yes, added in the fault-tolerant training PRs :)

Without parallelization, the copy to device (HtoD) and the compute run sequentially:

batch 0: [HtoD][forward][backward]
batch 1:                           [HtoD][forward][backward]

With parallelization, the latency of the HtoD copy can be hidden:

batch 0: [HtoD][forward][backward]
batch 1:       [HtoD]              [forward][backward]
"""

def __init__(
self,
dataloader,
batch_to_device: Callable,
profiler: "pl.profiler.base.BaseProfiler",
Contributor:
Could the profiler be optional? The Lightning Trainer would always set this, but it becomes a handy utility for users outside the project too (it might even draw them into the project).


num_prefetch_batches: int = 1,
) -> None:
self.iterator = profiled_iterator(dataloader, profiler)
self.profiler = profiler
self.batch_to_device = batch_to_device
self.batches: List = []
self.events: List = []
self.counter: int = 0
self.num_prefetch_batches = num_prefetch_batches

def __iter__(self) -> Generator:
self.counter = 1
if self.num_prefetch_batches == 1:
return self.prefetch_single_batch_iterator()
return self.prefetch_iterator()

def add_event(self, event) -> None:
self.events.append(event)

def add_batch(self, batch) -> None:
self.batches.append(batch)

@staticmethod
def start_record(event) -> None:
event.record()

def fetch_batch(self):
return self.batches.pop(0)

def wait(self) -> None:
event = self.events.pop(0)
event.wait()

def prefetch_iterator(self) -> Any:
cuda_stream = torch.cuda.Stream()

done = False
while not done:

for _ in range(self.num_prefetch_batches + 1):
if not done:
with torch.cuda.stream(cuda_stream):
Contributor:
Should this validate somewhere that torch.cuda is available?

batch_event = torch.cuda.Event()
self.add_event(batch_event)
try:
batch = next(self.iterator)
with self.profiler.profile("training_batch_to_device"):
batch = self.batch_to_device(batch)
self.add_batch(batch)
self.start_record(batch_event)
except StopIteration:
done = True

self.wait()
batch = self.fetch_batch()
# yield the current batch and whether the underlying iterator is exhausted
yield self.counter, batch, done
self.counter += 1

def prefetch_single_batch_iterator(self) -> Generator[Tuple[int, Any, bool], None, None]:
"""
Returns an iterator that pre-fetches and caches the next item.
Each value from the given iterable is yielded together with a running counter and a boolean indicating whether it is the last item.
See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
"""

try:
# the iterator may be empty from the beginning
last = next(self.iterator)
except StopIteration:
return

counter = 1

for val in self.iterator:
# yield last and has next
with self.profiler.profile("training_batch_to_device"):
last = self.batch_to_device(last)
yield counter, last, False
last = val
counter += 1
# yield last, no longer has next
with self.profiler.profile("training_batch_to_device"):
last = self.batch_to_device(last)
yield counter, last, True
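For completeness, a hedged sketch of driving the new LightningFetcher directly, outside the Trainer (illustrative only; it assumes a CUDA device and uses PassThroughProfiler as a stand-in for the trainer's profiler):

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.profiler import PassThroughProfiler
from pytorch_lightning.utilities.dataloader_fetcher import LightningFetcher

assert torch.cuda.is_available()  # the multi-batch path relies on CUDA streams and events

dataloader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

def batch_to_device(batch):
    # mirror what the accelerator would do: move every tensor of the batch to the GPU
    return [t.to("cuda", non_blocking=True) for t in batch]

fetcher = LightningFetcher(dataloader, batch_to_device, PassThroughProfiler(), num_prefetch_batches=2)
for counter, batch, is_last in fetcher:
    ...  # the forward/backward that would run here overlaps with the next HtoD copies

With the default num_prefetch_batches=1, __iter__ dispatches to prefetch_single_batch_iterator instead, which only looks one batch ahead and needs no CUDA stream.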