[Feat] Inter-batch parallelism #8700
@@ -0,0 +1,134 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Generator, List, Tuple

import torch

import pytorch_lightning as pl


def profiled_iterator(iterable, profiler):
    # wrap the raw iterator so every fetch is timed by the profiler
    iterator = iter(iterable)
    while True:
        try:
            with profiler.profile("get_train_batch"):
                yield next(iterator)
        except StopIteration:
            return


class LightningFetcher:
""" | ||
This class is used to perform ``pre-fetching`` for the ``train`` dataloader | ||
and apply inter batch parallelism if enabled. | ||
Comment on lines
+34
to
+35
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this could/should be used for evaluation too to still overlap the forward with host to device transfer There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, added in fault tolerant training PRs :) |
||
|
||
batch 0: [HtoD][forward][backward] | ||
batch 1: [HtoD][forward][backward] | ||
With parallelization, the latency of HtoD copy can be hidden: | ||
|
||
batch 0: [HtoD][forward][backward] | ||
batch 1: [HtoD] [forward][backward] | ||
""" | ||
    def __init__(
        self,
        dataloader,
        batch_to_device: Callable,
        profiler: "pl.profiler.base.BaseProfiler",
        num_prefetch_batches: int = 1,
    ) -> None:
        self.iterator = profiled_iterator(dataloader, profiler)
        self.profiler = profiler
        self.batch_to_device = batch_to_device
        self.batches: List = []
        self.events: List = []
        self.counter: int = 0
        self.num_prefetch_batches = num_prefetch_batches

Review comment on the ``profiler`` argument:
  Reviewer: could the profiler be optional? The Lightning trainer would always set this, but this becomes a handy utility for users outside the project too (it might even draw them into the project).
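A minimal sketch of what the reviewer suggests, defaulting to Lightning's existing no-op ``PassThroughProfiler``; the subclass name is illustrative only and assumes the ``LightningFetcher`` from this diff is in scope:

```python
from typing import Any, Callable, Optional

import pytorch_lightning as pl
from pytorch_lightning.profiler import PassThroughProfiler


class OptionalProfilerFetcher(LightningFetcher):
    """Illustrative only: fall back to a no-op profiler when none is given."""

    def __init__(
        self,
        dataloader: Any,
        batch_to_device: Callable,
        profiler: Optional["pl.profiler.base.BaseProfiler"] = None,
        num_prefetch_batches: int = 1,
    ) -> None:
        # PassThroughProfiler turns every ``profile(...)`` block into a no-op
        super().__init__(
            dataloader, batch_to_device, profiler or PassThroughProfiler(), num_prefetch_batches
        )
```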
    def __iter__(self) -> Generator:
        self.counter = 1
        if self.num_prefetch_batches == 1:
            return self.prefetch_single_batch_iterator()
        return self.prefetch_iterator()

    def add_event(self, event) -> None:
        self.events.append(event)

    def add_batch(self, batch) -> None:
        self.batches.append(batch)

    @staticmethod
    def start_record(event) -> None:
        event.record()

    def fetch_batch(self):
        return self.batches.pop(0)

    def wait(self) -> None:
        # make the current stream wait on the oldest recorded copy event
        event = self.events.pop(0)
        event.wait()
    def prefetch_iterator(self) -> Any:
        cuda_stream = torch.cuda.Stream()

        done = False
        while not done:

            for _ in range(self.num_prefetch_batches + 1):
                if not done:
                    # prefetch on a side stream so HtoD copies can overlap compute
                    with torch.cuda.stream(cuda_stream):
                        batch_event = torch.cuda.Event()
                        self.add_event(batch_event)
                        try:
                            batch = next(self.iterator)
                            with self.profiler.profile("training_batch_to_device"):
                                batch = self.batch_to_device(batch)
                            self.add_batch(batch)
                            self.start_record(batch_event)
                        except StopIteration:
                            done = True

            self.wait()
            batch = self.fetch_batch()
            # yield last and has next
            yield self.counter, batch, done
            self.counter += 1

Review comment on the ``torch.cuda.stream`` block:
  Reviewer: should this validate somewhere that torch.cuda is available?
  Author: We validate it here, as it is only supported on GPU devices: https://github.com/PyTorchLightning/pytorch-lightning/blob/inter_batch_parallism/pytorch_lightning/trainer/trainer.py#L1344
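A sketch of the kind of guard this thread refers to; the helper name is hypothetical, and per the author the PR performs the equivalent check inside the Trainer rather than in the fetcher itself:

```python
import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _validate_inter_batch_parallelism_supported() -> None:
    # hypothetical helper: fail fast when no CUDA device is available,
    # since the multi-batch prefetch path uses CUDA streams and events
    if not torch.cuda.is_available():
        raise MisconfigurationException(
            "Inter-batch parallelism requires a CUDA-capable device."
        )
```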
    def prefetch_single_batch_iterator(self) -> Generator[Tuple[int, Any, bool], None, None]:
        """
        Returns an iterator that pre-fetches and caches the next item.
        The values are passed through from the given iterable, with an added
        boolean indicating whether this is the last item.
        See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
        """
        try:
            # the iterator may be empty from the beginning
            last = next(self.iterator)
        except StopIteration:
            return

        counter = 1

        for val in self.iterator:
            # yield last and has next
            with self.profiler.profile("training_batch_to_device"):
                last = self.batch_to_device(last)
            yield counter, last, False
            last = val
            counter += 1
        # yield last, no longer has next
        with self.profiler.profile("training_batch_to_device"):
            last = self.batch_to_device(last)
        yield counter, last, True
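A hedged usage sketch (editor's example, not from the PR) driving the fetcher directly; the toy dataset, the ``batch_to_device`` helper, and the profiler choice are illustrative:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.profiler import SimpleProfiler

dataset = TensorDataset(torch.randn(64, 3))
dataloader = DataLoader(dataset, batch_size=8)


def batch_to_device(batch):
    # move tensors to the GPU when one is available; no-op on CPU
    return [t.cuda(non_blocking=True) for t in batch] if torch.cuda.is_available() else batch


fetcher = LightningFetcher(
    dataloader,
    batch_to_device=batch_to_device,
    profiler=SimpleProfiler(),
    num_prefetch_batches=1,  # takes the single-batch lookahead path, so it also runs on CPU
)

for counter, batch, is_last in fetcher:
    pass  # a training step would consume ``batch``; ``is_last`` flags the final batch
```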
Review comment on the class name:
  Reviewer: n00b question: certain classes are prefixed with ``Lightning`` while others aren't. Utilities like this don't contain logic specific to the rest of the framework, so I wonder if we could call this just ``Prefetcher``? It can also help users feel like they're not needing to learn Lightning-specific things.
  Reply: I would propose ``DataFetcher``.