3/n inter batch parallelism #9052

Merged
merged 96 commits on Aug 24, 2021
Changes from 84 commits
Commits
96 commits
c7d22fe
update
tchaton Aug 20, 2021
60df25a
resolve tests
tchaton Aug 20, 2021
ea3311e
update
tchaton Aug 20, 2021
862349e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2021
d551859
update
tchaton Aug 20, 2021
6849c62
Merge branch '1/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 20, 2021
c190af5
update
tchaton Aug 20, 2021
7d67e42
update
tchaton Aug 20, 2021
fadeddc
update
tchaton Aug 20, 2021
53a32c1
update on comments
tchaton Aug 23, 2021
28b066f
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
tchaton Aug 23, 2021
1bc068f
update
tchaton Aug 23, 2021
b094b34
update on comments
tchaton Aug 23, 2021
27a924d
typo
tchaton Aug 23, 2021
617333d
resolve bug
tchaton Aug 23, 2021
6eeba87
update
tchaton Aug 23, 2021
0326a33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
06dae29
update on comments
tchaton Aug 23, 2021
9ea1953
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
becdb3e
update
tchaton Aug 23, 2021
4e578f2
Merge branch '1/n_inter_batch_parallelism' into 2/n_inter_batch_paral…
tchaton Aug 23, 2021
9bd8094
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
e9d6760
update
tchaton Aug 23, 2021
842a6ae
update
tchaton Aug 23, 2021
a70c052
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
d08976f
resolve tests
tchaton Aug 23, 2021
ae3ac40
resolve
tchaton Aug 23, 2021
7f0b4b4
update
tchaton Aug 23, 2021
e7272e3
update
tchaton Aug 23, 2021
99a8853
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
3bd09e1
update
tchaton Aug 23, 2021
c7b5ff0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
f571f8c
add teardown
tchaton Aug 23, 2021
c34c914
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 23, 2021
1b5b911
update
tchaton Aug 23, 2021
403ef3c
update on comments
tchaton Aug 23, 2021
96ee949
add back comment
tchaton Aug 23, 2021
68acd44
Fix diff
carmocca Aug 23, 2021
3c1eca6
Merge branch 'master' into 2/n_inter_batch_parallelism
tchaton Aug 23, 2021
a995764
update on comments
tchaton Aug 23, 2021
307f446
resolve tests
tchaton Aug 23, 2021
d69b6fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
09cfa02
Merge branch '2/n_inter_batch_parallelism' into 3/n_inter_batch_paral…
tchaton Aug 23, 2021
104256c
update
tchaton Aug 23, 2021
89fce20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
e415c60
update
tchaton Aug 23, 2021
126818b
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 23, 2021
d81c971
cleanup
tchaton Aug 23, 2021
d473903
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
1ed2f31
update
tchaton Aug 23, 2021
87432e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
fb98c86
update
tchaton Aug 23, 2021
54fda07
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 23, 2021
5bf9090
resolve memory leak
tchaton Aug 23, 2021
9920162
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
899c349
update on comments
tchaton Aug 23, 2021
eaf2fc7
Merge branch '2/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 23, 2021
589c717
resolve tests
tchaton Aug 23, 2021
38fa806
update
tchaton Aug 23, 2021
c2ec026
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
1e1a482
update
tchaton Aug 23, 2021
e26e27d
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 23, 2021
71755ca
update
tchaton Aug 23, 2021
8bbb957
update on comments
tchaton Aug 24, 2021
5ba95ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
63c11e1
update
tchaton Aug 24, 2021
fe92512
update
tchaton Aug 24, 2021
5a738de
update
tchaton Aug 24, 2021
a4a2171
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
29a7244
improve test
tchaton Aug 24, 2021
7f7b648
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 24, 2021
985e47a
update
tchaton Aug 24, 2021
95726a8
cleanup
tchaton Aug 24, 2021
fa75fb2
more cleanup
tchaton Aug 24, 2021
0f74e4f
resolve batch_idx sync
tchaton Aug 24, 2021
6667588
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
570edff
drop weird test
tchaton Aug 24, 2021
44d5052
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 24, 2021
89f372a
update
tchaton Aug 24, 2021
3fddde7
update on comments
tchaton Aug 24, 2021
cd1c8c1
resolve typing
tchaton Aug 24, 2021
3fb109d
updte
tchaton Aug 24, 2021
4de4da5
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 24, 2021
c00278d
update on comments
tchaton Aug 24, 2021
edaf099
update
tchaton Aug 24, 2021
9706854
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
1ad63bc
resolve comments
tchaton Aug 24, 2021
91f5954
updte
tchaton Aug 24, 2021
496f419
remove + 1
tchaton Aug 24, 2021
28e855a
update on comments
tchaton Aug 24, 2021
9ba4394
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 24, 2021
240a7b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
33f1323
No longer need to seed
carmocca Aug 24, 2021
793ed8c
Merge branch 'master' into 3/n_inter_batch_parallelism
carmocca Aug 24, 2021
5b9d81e
update
tchaton Aug 24, 2021
2c44e7b
typo
tchaton Aug 24, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -181,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `on_train_epoch_end` from `Accelerator` ([#9035](https://github.com/PyTorchLightning/pytorch-lightning/pull/9035))


+- Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052))
+
+
 ### Fixed

 - Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
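
For context, the pattern that `DataLoaderIterDataFetcher` is meant to enable is a `training_step` that receives the dataloader iterator and fetches batches itself. The sketch below is illustrative only: the `ManualFetchModel` name is made up, and the exact objects yielded by `next(dataloader_iter)` (raw batch vs. wrapped batch) are an assumption, not something this diff pins down.

```python
import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl


class ManualFetchModel(pl.LightningModule):
    """Illustrative only: consumes the dataloader iterator inside training_step."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, dataloader_iter):
        # The loop hands over the iterator; the step decides when (and how often) to fetch.
        batch = next(dataloader_iter)
        loss = self.layer(batch).sum()
        # The trainer cannot infer the batch size here, so it is passed explicitly
        # (see the `self.log` check added in pytorch_lightning/core/lightning.py below).
        self.log("train_loss", loss, batch_size=batch.size(0))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # A plain tensor works as a map-style dataset; default collate stacks rows.
        return DataLoader(torch.randn(64, 32), batch_size=8)
```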
5 changes: 5 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -454,6 +454,11 @@ def log(
                     f" of {list(self._metric_attributes.values())}"
                 )

+        if is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) and batch_size is None:
+            raise MisconfigurationException(
+                "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
+            )
+
         results.log(
             self._current_fx_name,
             name,
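
The guard added above means that once `training_step` explicitly declares `dataloader_iter`, the trainer can no longer infer a batch size for `self.log`, so it has to be passed explicitly. A minimal standalone sketch of that kind of signature check follows; `declares_param_explicitly` and `check_log_batch_size` are simplified stand-ins, not the library's `is_param_in_hook_signature` implementation.

```python
import inspect
from typing import Callable, Optional


def declares_param_explicitly(fn: Callable, name: str) -> bool:
    # Simplified stand-in for `is_param_in_hook_signature(fn, name, explicit=True)`:
    # the parameter has to be spelled out in the signature, not hidden behind *args/**kwargs.
    params = inspect.signature(fn).parameters
    return name in params and params[name].kind not in (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )


def check_log_batch_size(training_step: Callable, batch_size: Optional[int]) -> None:
    # Mirrors the intent of the new guard in `LightningModule.log`.
    if declares_param_explicitly(training_step, "dataloader_iter") and batch_size is None:
        raise ValueError(
            "With `def training_step(self, dataloader_iter)`, "
            "`self.log(..., batch_size=...)` should be provided."
        )


def training_step(self, dataloader_iter):  # toy hook, only used for the checks below
    ...


try:
    check_log_batch_size(training_step, batch_size=None)
except ValueError as err:
    print(err)  # the guard fires because no explicit batch_size was given

check_log_batch_size(training_step, batch_size=8)  # fine once batch_size is explicit
```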
1 change: 0 additions & 1 deletion pytorch_lightning/loops/__init__.py
@@ -17,4 +17,3 @@
 from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop  # noqa: F401
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop  # noqa: F401
 from pytorch_lightning.loops.fit_loop import FitLoop  # noqa: F401
-from pytorch_lightning.loops.processors import IteratorBatchProcessor  # noqa: F401
7 changes: 5 additions & 2 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -547,12 +547,15 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio
             the keyword arguments for the training step
         """
         # enable not needing to add opt_idx to training_step
-        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
+        step_kwargs = OrderedDict([("batch", batch)])

         lightning_module = self.trainer.lightning_module
+        training_step_fx = getattr(lightning_module, "training_step")
+
+        if is_param_in_hook_signature(training_step_fx, "batch_idx"):
+            step_kwargs["batch_idx"] = batch_idx

         if len(self.trainer.optimizers) > 1:
-            training_step_fx = getattr(lightning_module, "training_step")
             has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
             if has_opt_idx_in_train_step:
                 if not lightning_module.automatic_optimization:
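
With this change, `batch_idx` is forwarded to `training_step` only when its signature declares it, so `def training_step(self, batch)` keeps working. A rough standalone sketch of that conditional kwargs construction follows; `accepts_param` and `build_step_kwargs` are illustrative helpers, not the loop's internals.

```python
import inspect
from collections import OrderedDict
from typing import Any, Callable


def accepts_param(fn: Callable, name: str) -> bool:
    # Loose stand-in for `is_param_in_hook_signature`.
    return name in inspect.signature(fn).parameters


def build_step_kwargs(training_step: Callable, batch: Any, batch_idx: int) -> OrderedDict:
    step_kwargs = OrderedDict([("batch", batch)])
    # Only forward batch_idx if the hook actually asks for it.
    if accepts_param(training_step, "batch_idx"):
        step_kwargs["batch_idx"] = batch_idx
    return step_kwargs


def step_without_idx(batch):
    return batch


def step_with_idx(batch, batch_idx):
    return batch, batch_idx


print(build_step_kwargs(step_without_idx, batch=[1, 2], batch_idx=7))  # no 'batch_idx' key
print(build_step_kwargs(step_with_idx, batch=[1, 2], batch_idx=7))     # includes 'batch_idx'
```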
13 changes: 5 additions & 8 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Iterator, List, Optional, Sequence, Union
+from typing import Any, List, Optional, Sequence, Union

 from deprecate.utils import void
 from torch.utils.data.dataloader import DataLoader
@@ -98,17 +98,14 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         """Performs evaluation on one single dataloader"""
         void(*args, **kwargs)

+        dataloader_idx: int = self.current_dataloader_idx
         dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
-        dataloader = self.trainer.data_connector.get_profiled_dataloader(
-            dataloader, dataloader_idx=self.current_dataloader_idx
-        )
+        dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
         dataloader_iter = iter(dataloader)

-        dl_max_batches = self._max_batches[self.current_dataloader_idx]
+        dl_max_batches = self._max_batches[dataloader_idx]

-        dl_outputs = self.epoch_loop.run(
-            dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
-        )
+        dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders)

         # store batch level output per dataloader
         if self.should_track_batch_outputs_for_epoch_end:
9 changes: 6 additions & 3 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -37,6 +37,7 @@ def __init__(self) -> None:
         self._num_dataloaders: Optional[int] = None
         self.outputs: List[STEP_OUTPUT] = []
         self.batch_progress = Progress()
+        self.dataloader_iter: Optional[Iterator] = None

     @property
     def done(self) -> bool:
@@ -66,10 +67,12 @@ def on_run_start(
             dl_max_batches: maximum number of batches the dataloader can produce
             num_dataloaders: the total number of dataloaders
         """
-        void(dataloader_iter, dataloader_idx)
+        void(dataloader_idx)
         self._dl_max_batches = dl_max_batches
         self._num_dataloaders = num_dataloaders

+        self.dataloader_iter = enumerate(dataloader_iter, self.batch_progress.current.ready)
+
     def advance(
         self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
     ) -> None:
@@ -84,9 +87,9 @@ def advance(
         Raises:
             StopIteration: If the current batch is None
         """
-        void(dl_max_batches, num_dataloaders)
+        void(dataloader_iter, dl_max_batches, num_dataloaders)

-        batch_idx, (batch, _) = next(dataloader_iter)
+        batch_idx, (batch, _) = next(self.dataloader_iter)

         if batch is None:
             raise StopIteration
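
The key change here is that the iterator is wrapped once in `on_run_start` with `enumerate(dataloader_iter, self.batch_progress.current.ready)`, so batch indices continue from the progress counter (for example after a restart) rather than restarting at zero. A tiny sketch of the idea with made-up numbers:

```python
# Illustrative only: resuming batch indices by seeding enumerate with a start offset,
# as the loop does with `self.batch_progress.current.ready`.
remaining_batches = ["batch_3", "batch_4", "batch_5"]  # pretend the first 3 were already consumed
already_ready = 3                                      # stand-in for batch_progress.current.ready

dataloader_iter = enumerate(iter(remaining_batches), already_ready)

for batch_idx, batch in dataloader_iter:
    print(batch_idx, batch)  # 3 batch_3, 4 batch_4, 5 batch_5 -- indices line up with progress
```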
67 changes: 27 additions & 40 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -11,24 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Union

 import torch

 from pytorch_lightning import loops  # import as loops to avoid circular imports
 from pytorch_lightning.loops.batch import TrainingBatchLoop
-from pytorch_lightning.loops.processors import IteratorBatchProcessor
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import STEP_OUTPUT

-# TODO: currently, the batch processor is only a loop when tbptt is enabled.
-# As we introduce more specialized batch processors, we may want to choose a
-# more suitable abstraction for them.
-BATCH_LOOP_TYPE = Optional[Tuple[TrainingBatchLoop, IteratorBatchProcessor]]
-

 class TrainingEpochLoop(loops.Loop):
     """
@@ -50,7 +45,7 @@ def __init__(self, min_steps: int, max_steps: int):
         self.batch_progress = Progress()
         self.scheduler_progress = SchedulerProgress()

-        self.batch_loop: BATCH_LOOP_TYPE = None
+        self.batch_loop: Optional[TrainingBatchLoop] = None
         self.val_loop: Optional["loops.EvaluationLoop"] = None

         self._results = ResultCollection(training=True)
@@ -81,7 +76,7 @@ def done(self) -> bool:

     def connect(
         self,
-        batch_loop: BATCH_LOOP_TYPE = None,
+        batch_loop: TrainingBatchLoop = None,
         val_loop: Optional["loops.EvaluationLoop"] = None,
     ) -> None:
         """Optionally connect a custom batch or validation loop to this training epoch loop."""
@@ -102,14 +97,16 @@ def reset(self) -> None:
         self.scheduler_progress.current.reset()
         self.batch_loop.optim_progress.reset_on_epoch()

-    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+    def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
         # hook
         self.trainer.logger_connector.on_epoch_start()
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")
         self.trainer.fit_loop.epoch_progress.increment_started()

-    def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
+        self._prepare_dataloader_iter(dataloader_iter)
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
         """Runs a single training batch.

         Args:
@@ -118,33 +115,18 @@
         Raises:
             StopIteration: When the epoch is canceled by the user returning -1
         """
-        if isinstance(self.batch_loop, IteratorBatchProcessor):
-            # By contract, when taking `dataloader_iter` as an argument,
-            # `training_step` is responsible for reporting `is_last` in the
-            # result dict, which is used to determine the stop condition for
-            # the epoch. So as long as `advance` is invoked, it's correct to
-            # assume that there are more batches to be processed.
-            self.batch_progress.increment_ready()
-            with self.trainer.profiler.profile("run_training_batch"):
-                batch_output = self.batch_loop.run(dataloader_iter)
-            self.batch_progress.increment_processed()
-            is_last = batch_output.is_last
-        else:
-            _, (batch, is_last) = next(dataloader_iter)
-
-            if not self.trainer.data_connector.train_data_fetcher.store_on_device:
-                with self.trainer.profiler.profile("training_batch_to_device"):
-                    batch = self.trainer.accelerator.batch_to_device(batch)
-
-            # ------------------------------------
-            # TRAINING_STEP + TRAINING_STEP_END
-            # ------------------------------------
-            self.batch_progress.increment_ready()
-
-            with self.trainer.profiler.profile("run_training_batch"):
-                batch_output = self.batch_loop.run(batch, self.batch_idx)
-
-            self.batch_progress.increment_processed()
+        batch_idx, (batch, is_last) = next(self.dataloader_iter)
+
+        if not self.trainer.data_connector.train_data_fetcher.store_on_device:
+            with self.trainer.profiler.profile("training_batch_to_device"):
+                batch = self.trainer.accelerator.batch_to_device(batch)
+
+        self.batch_progress.increment_ready()
+
+        with self.trainer.profiler.profile("run_training_batch"):
+            batch_output = self.batch_loop.run(batch, batch_idx)
+
+        self.batch_progress.increment_processed()

         self.is_last_batch = is_last

@@ -162,8 +144,7 @@
         processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)

         # hook
-        if not isinstance(self.batch_loop, IteratorBatchProcessor):
-            self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
+        self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
         self.trainer.call_hook("on_batch_end")
         self.trainer.logger_connector.on_batch_end()
@@ -404,3 +385,9 @@ def _save_loggers_on_train_batch_end(self) -> None:
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
         if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
             self.trainer.logger.save()
+
+    def _prepare_dataloader_iter(self, dataloader_iter: AbstractDataFetcher) -> None:
+        if not isinstance(dataloader_iter, DataLoaderIterDataFetcher):
+            # restore iteration
+            dataloader_iter = enumerate(dataloader_iter, self.batch_idx + 1)
+        self.dataloader_iter = dataloader_iter
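
`_prepare_dataloader_iter` is the dispatch point for the new flow: a regular fetcher is wrapped in `enumerate(..., self.batch_idx + 1)` so that `advance` receives `(batch_idx, (batch, is_last))` pairs, while a `DataLoaderIterDataFetcher` is passed through untouched because the user's `training_step` drives it directly. A self-contained sketch of that dispatch using stand-in classes (`StubFetcher` and `StubDataLoaderIterFetcher` are made up; the real fetchers live in `pytorch_lightning.utilities.fetching`):

```python
from typing import Iterator, List, Tuple


class StubFetcher:
    """Stand-in for a prefetching data fetcher yielding (batch, is_last) pairs."""

    def __init__(self, batches: List[str]) -> None:
        self.batches = batches

    def __iter__(self) -> Iterator[Tuple[str, bool]]:
        for i, batch in enumerate(self.batches):
            yield batch, i == len(self.batches) - 1


class StubDataLoaderIterFetcher(StubFetcher):
    """Stand-in for DataLoaderIterDataFetcher: the training_step consumes it directly."""


def prepare_dataloader_iter(fetcher, last_batch_idx: int):
    # Mirrors the new `_prepare_dataloader_iter`: only plain fetchers get enumerated,
    # starting after the last processed batch index so progress tracking stays consistent.
    if not isinstance(fetcher, StubDataLoaderIterFetcher):
        return enumerate(fetcher, last_batch_idx + 1)
    return fetcher


it = prepare_dataloader_iter(StubFetcher(["a", "b"]), last_batch_idx=-1)
print(next(it))  # (0, ('a', False)) -- the shape `advance` unpacks as batch_idx, (batch, is_last)
```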
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -14,7 +14,7 @@

 import logging
 from contextlib import suppress
-from typing import Iterator, Optional
+from typing import Optional

 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.epoch import TrainingEpochLoop
15 changes: 0 additions & 15 deletions pytorch_lightning/loops/processors/__init__.py

This file was deleted.
