3/n inter batch parallelism (#9052)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <[email protected]>
3 people authored Aug 24, 2021
1 parent b9443a0 commit f959b13
Showing 21 changed files with 275 additions and 665 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `on_train_epoch_end` from `Accelerator` ([#9035](https://github.com/PyTorchLightning/pytorch-lightning/pull/9035))


- Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052))


### Fixed

- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
9 changes: 9 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -456,6 +456,15 @@ def log(
f" of {list(self._metric_attributes.values())}"
)

if (
self.trainer.training
and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True)
and batch_size is None
):
raise MisconfigurationException(
"With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
)

results.log(
self._current_fx_name,
name,
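As a usage illustration of the check added above: with the `dataloader_iter` signature the trainer cannot infer a batch size for logging, so `self.log` must receive one explicitly. The sketch below is hypothetical (the model, data shapes, and the assumption that `next(dataloader_iter)` yields the batch directly are not taken from this PR); it only shows the required call shape.

```python
import torch
from pytorch_lightning import LightningModule


class IterStepModel(LightningModule):
    """Hypothetical module using the `dataloader_iter` training_step signature."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, dataloader_iter):
        # Assumption: advancing the fetcher here yields the batch itself.
        batch = next(dataloader_iter)
        loss = self.layer(batch).sum()
        # `batch_size` must be passed explicitly in this mode; omitting it
        # raises the MisconfigurationException introduced above.
        self.log("train_loss", loss, batch_size=batch.size(0))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```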
1 change: 0 additions & 1 deletion pytorch_lightning/loops/__init__.py
@@ -17,4 +17,3 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
from pytorch_lightning.loops.processors import IteratorBatchProcessor # noqa: F401
7 changes: 5 additions & 2 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -547,12 +547,15 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio
the keyword arguments for the training step
"""
# enable not needing to add opt_idx to training_step
step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
step_kwargs = OrderedDict([("batch", batch)])

lightning_module = self.trainer.lightning_module
training_step_fx = getattr(lightning_module, "training_step")

if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
step_kwargs["batch_idx"] = batch_idx

if len(self.trainer.optimizers) > 1:
training_step_fx = getattr(lightning_module, "training_step")
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
if has_opt_idx_in_train_step:
if not lightning_module.automatic_optimization:
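The hunk above makes `batch_idx` an optional parameter of `training_step`. A self-contained sketch of the same dispatch idea, using a plain `inspect`-based check instead of Lightning's `is_param_in_hook_signature` helper (the function names below are made up for illustration):

```python
import inspect
from collections import OrderedDict
from typing import Any, Callable


def build_step_kwargs(training_step: Callable, batch: Any, batch_idx: int) -> OrderedDict:
    """Only forward `batch_idx` if the user's training_step declares it."""
    params = inspect.signature(training_step).parameters
    step_kwargs = OrderedDict(batch=batch)
    if "batch_idx" in params:
        step_kwargs["batch_idx"] = batch_idx
    return step_kwargs


# Usage: a step defined as `def training_step(self, batch)` receives only the
# batch, while `def training_step(self, batch, batch_idx)` also gets the index.
def step_without_idx(batch):
    return batch


def step_with_idx(batch, batch_idx):
    return batch, batch_idx


assert list(build_step_kwargs(step_without_idx, "b", 0)) == ["batch"]
assert list(build_step_kwargs(step_with_idx, "b", 0)) == ["batch", "batch_idx"]
```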
11 changes: 4 additions & 7 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -98,17 +98,14 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
"""Performs evaluation on one single dataloader"""
void(*args, **kwargs)

dataloader_idx: int = self.current_dataloader_idx
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader = self.trainer.data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=self.current_dataloader_idx
)
dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
dataloader_iter = iter(dataloader)

dl_max_batches = self._max_batches[self.current_dataloader_idx]
dl_max_batches = self._max_batches[dataloader_idx]

dl_outputs = self.epoch_loop.run(
dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
)
dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders)

# store batch level output per dataloader
if self.should_track_batch_outputs_for_epoch_end:
10 changes: 7 additions & 3 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -19,6 +19,7 @@
from torch import Tensor

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -37,6 +38,7 @@ def __init__(self) -> None:
self._num_dataloaders: Optional[int] = None
self.outputs: List[STEP_OUTPUT] = []
self.batch_progress = Progress()
self.dataloader_iter: Optional[Iterator] = None

@property
def done(self) -> bool:
@@ -66,10 +68,12 @@ def on_run_start(
dl_max_batches: maximum number of batches the dataloader can produce
num_dataloaders: the total number of dataloaders
"""
void(dataloader_iter, dataloader_idx)
void(dataloader_idx)
self._dl_max_batches = dl_max_batches
self._num_dataloaders = num_dataloaders

self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready)

def advance(
self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
) -> None:
@@ -84,9 +88,9 @@ def advance(
Raises:
StopIteration: If the current batch is None
"""
void(dl_max_batches, num_dataloaders)
void(dataloader_iter, dl_max_batches, num_dataloaders)

batch_idx, (batch, _) = next(dataloader_iter)
batch_idx, (batch, _) = next(self.dataloader_iter)

if batch is None:
raise StopIteration
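Both epoch loops now wrap the incoming iterator via `_prepare_dataloader_iter` and unpack `(batch_idx, (batch, is_last))` tuples from it. A minimal sketch of that numbering step, assuming the helper essentially enumerates the fetcher's output starting at the current progress counter (the real helper also special-cases the `dataloader_iter`-style fetcher, which is omitted here):

```python
from typing import Any, Iterator, Tuple


def prepare_dataloader_iter(dataloader_iter: Iterator, batch_idx: int) -> Iterator[Tuple[int, Any]]:
    """Sketch: number the fetched items starting at the current progress counter."""
    return enumerate(dataloader_iter, batch_idx)


# Usage: a fetcher that yields `(batch, is_last)` pairs becomes an iterator of
# `(batch_idx, (batch, is_last))`, matching the unpacking in `advance` above.
fetched = iter([("batch-0", False), ("batch-1", True)])
it = prepare_dataloader_iter(fetched, 0)
assert next(it) == (0, ("batch-0", False))
assert next(it) == (1, ("batch-1", True))
```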
61 changes: 21 additions & 40 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -11,24 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Union

import torch

from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.processors import IteratorBatchProcessor
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT

# TODO: currently, the batch processor is only a loop when tbptt is enabled.
# As we introduce more specialized batch processors, we may want to choose a
# more suitable abstraction for them.
BATCH_LOOP_TYPE = Optional[Tuple[TrainingBatchLoop, IteratorBatchProcessor]]


class TrainingEpochLoop(loops.Loop):
"""
@@ -50,7 +45,7 @@ def __init__(self, min_steps: int, max_steps: int):
self.batch_progress = Progress()
self.scheduler_progress = SchedulerProgress()

self.batch_loop: BATCH_LOOP_TYPE = None
self.batch_loop: Optional[TrainingBatchLoop] = None
self.val_loop: Optional["loops.EvaluationLoop"] = None

self._results = ResultCollection(training=True)
@@ -81,7 +76,7 @@ def done(self) -> bool:

def connect(
self,
batch_loop: BATCH_LOOP_TYPE = None,
batch_loop: TrainingBatchLoop = None,
val_loop: Optional["loops.EvaluationLoop"] = None,
) -> None:
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
@@ -102,14 +97,16 @@ def reset(self) -> None:
self.scheduler_progress.current.reset()
self.batch_loop.optim_progress.reset_on_epoch()

def on_run_start(self, *args: Any, **kwargs: Any) -> None:
def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
# hook
self.trainer.logger_connector.on_epoch_start()
self.trainer.call_hook("on_epoch_start")
self.trainer.call_hook("on_train_epoch_start")
self.trainer.fit_loop.epoch_progress.increment_started()

def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_idx + 1)

def advance(self, *args: Any, **kwargs: Any) -> None:
"""Runs a single training batch.
Args:
@@ -118,33 +115,18 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
Raises:
StopIteration: When the epoch is canceled by the user returning -1
"""
if isinstance(self.batch_loop, IteratorBatchProcessor):
# By contract, when taking `dataloader_iter` as an argument,
# `training_step` is responsible for reporting `is_last` in the
# result dict, which is used to determine the stop condition for
# the epoch. So as long as `advance` is invoked, it's correct to
# assume that there are more batches to be processed.
self.batch_progress.increment_ready()
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(dataloader_iter)
self.batch_progress.increment_processed()
is_last = batch_output.is_last
else:
_, (batch, is_last) = next(dataloader_iter)

if not self.trainer.data_connector.train_data_fetcher.store_on_device:
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch)

# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
self.batch_progress.increment_ready()

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, self.batch_idx)

self.batch_progress.increment_processed()
batch_idx, (batch, is_last) = next(self.dataloader_iter)

if not self.trainer.data_connector.train_data_fetcher.store_on_device:
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch)

self.batch_progress.increment_ready()

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)

self.batch_progress.increment_processed()

self.is_last_batch = is_last

@@ -162,8 +144,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)

# hook
if not isinstance(self.batch_loop, IteratorBatchProcessor):
self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
self.trainer.call_hook("on_batch_end")
self.trainer.logger_connector.on_batch_end()

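The rewritten `advance` no longer branches on `IteratorBatchProcessor`: it always pulls `(batch_idx, (batch, is_last))` tuples and relies on the fetcher's `is_last` flag to end the epoch. A toy sketch of that contract (the helper name and toy fetcher below are illustrative only, not part of this PR):

```python
from typing import Any, Callable, Iterator, List, Tuple


def run_epoch(
    dataloader_iter: Iterator[Tuple[int, Tuple[Any, bool]]],
    run_batch: Callable[[Any, int], Any],
) -> List[Any]:
    """Consume (batch_idx, (batch, is_last)) tuples until the fetcher flags the last batch."""
    outputs = []
    while True:
        batch_idx, (batch, is_last) = next(dataloader_iter)
        outputs.append(run_batch(batch, batch_idx))
        if is_last:
            break
    return outputs


# A toy "fetcher" that pre-computes the is_last flag, so the loop never has to
# peek ahead itself.
batches = ["a", "b", "c"]
fetcher = enumerate((b, i == len(batches) - 1) for i, b in enumerate(batches))
assert run_epoch(fetcher, lambda batch, idx: (idx, batch)) == [(0, "a"), (1, "b"), (2, "c")]
```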
15 changes: 0 additions & 15 deletions pytorch_lightning/loops/processors/__init__.py

This file was deleted.
