Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3/n inter batch parallelism #9052

Merged
merged 96 commits into from
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from 78 commits
Commits
Show all changes
96 commits
Select commit Hold shift + click to select a range
c7d22fe
update
tchaton Aug 20, 2021
60df25a
resolve tests
tchaton Aug 20, 2021
ea3311e
update
tchaton Aug 20, 2021
862349e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2021
d551859
update
tchaton Aug 20, 2021
6849c62
Merge branch '1/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 20, 2021
c190af5
update
tchaton Aug 20, 2021
7d67e42
update
tchaton Aug 20, 2021
fadeddc
update
tchaton Aug 20, 2021
53a32c1
update on comments
tchaton Aug 23, 2021
28b066f
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
tchaton Aug 23, 2021
1bc068f
update
tchaton Aug 23, 2021
b094b34
update on comments
tchaton Aug 23, 2021
27a924d
typo
tchaton Aug 23, 2021
617333d
resolve bug
tchaton Aug 23, 2021
6eeba87
update
tchaton Aug 23, 2021
0326a33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
06dae29
update on comments
tchaton Aug 23, 2021
9ea1953
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
becdb3e
update
tchaton Aug 23, 2021
4e578f2
Merge branch '1/n_inter_batch_parallelism' into 2/n_inter_batch_paral…
tchaton Aug 23, 2021
9bd8094
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
e9d6760
update
tchaton Aug 23, 2021
842a6ae
update
tchaton Aug 23, 2021
a70c052
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
d08976f
resolve tests
tchaton Aug 23, 2021
ae3ac40
resolve
tchaton Aug 23, 2021
7f0b4b4
update
tchaton Aug 23, 2021
e7272e3
update
tchaton Aug 23, 2021
99a8853
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
3bd09e1
update
tchaton Aug 23, 2021
c7b5ff0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
f571f8c
add teardown
tchaton Aug 23, 2021
c34c914
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 23, 2021
1b5b911
update
tchaton Aug 23, 2021
403ef3c
update on comments
tchaton Aug 23, 2021
96ee949
add back comment
tchaton Aug 23, 2021
68acd44
Fix diff
carmocca Aug 23, 2021
3c1eca6
Merge branch 'master' into 2/n_inter_batch_parallelism
tchaton Aug 23, 2021
a995764
update on comments
tchaton Aug 23, 2021
307f446
resolve tests
tchaton Aug 23, 2021
d69b6fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
09cfa02
Merge branch '2/n_inter_batch_parallelism' into 3/n_inter_batch_paral…
tchaton Aug 23, 2021
104256c
update
tchaton Aug 23, 2021
89fce20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
e415c60
update
tchaton Aug 23, 2021
126818b
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 23, 2021
d81c971
cleanup
tchaton Aug 23, 2021
d473903
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
1ed2f31
update
tchaton Aug 23, 2021
87432e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
fb98c86
update
tchaton Aug 23, 2021
54fda07
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 23, 2021
5bf9090
resolve memory leak
tchaton Aug 23, 2021
9920162
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
899c349
update on comments
tchaton Aug 23, 2021
eaf2fc7
Merge branch '2/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 23, 2021
589c717
resolve tests
tchaton Aug 23, 2021
38fa806
update
tchaton Aug 23, 2021
c2ec026
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
1e1a482
update
tchaton Aug 23, 2021
e26e27d
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 23, 2021
71755ca
update
tchaton Aug 23, 2021
8bbb957
update on comments
tchaton Aug 24, 2021
5ba95ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
63c11e1
update
tchaton Aug 24, 2021
fe92512
update
tchaton Aug 24, 2021
5a738de
update
tchaton Aug 24, 2021
a4a2171
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
29a7244
improve test
tchaton Aug 24, 2021
7f7b648
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 24, 2021
985e47a
update
tchaton Aug 24, 2021
95726a8
cleanup
tchaton Aug 24, 2021
fa75fb2
more cleanup
tchaton Aug 24, 2021
0f74e4f
resolve batch_idx sync
tchaton Aug 24, 2021
6667588
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
570edff
drop weird test
tchaton Aug 24, 2021
44d5052
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 24, 2021
89f372a
update
tchaton Aug 24, 2021
3fddde7
update on comments
tchaton Aug 24, 2021
cd1c8c1
resolve typing
tchaton Aug 24, 2021
3fb109d
updte
tchaton Aug 24, 2021
4de4da5
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 24, 2021
c00278d
update on comments
tchaton Aug 24, 2021
edaf099
update
tchaton Aug 24, 2021
9706854
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
1ad63bc
resolve comments
tchaton Aug 24, 2021
91f5954
updte
tchaton Aug 24, 2021
496f419
remove + 1
tchaton Aug 24, 2021
28e855a
update on comments
tchaton Aug 24, 2021
9ba4394
Merge branch '3/n_inter_batch_parallelism' of https://github.com/PyTo…
tchaton Aug 24, 2021
240a7b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2021
33f1323
No longer need to seed
carmocca Aug 24, 2021
793ed8c
Merge branch 'master' into 3/n_inter_batch_parallelism
carmocca Aug 24, 2021
5b9d81e
update
tchaton Aug 24, 2021
2c44e7b
typo
tchaton Aug 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `on_train_epoch_end` from `Accelerator` ([#9035](https://github.com/PyTorchLightning/pytorch-lightning/pull/9035))


- Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052))


### Fixed

- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,11 @@ def log(
f" of {list(self._metric_attributes.values())}"
)

if is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) and batch_size is None:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
raise MisconfigurationException(
"When the `dataloader_iter` is requested within the `training_step`, `batch_size` should be provided."
tchaton marked this conversation as resolved.
Show resolved Hide resolved
)

results.log(
self._current_fx_name,
name,
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/loops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,3 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
from pytorch_lightning.loops.processors import IteratorBatchProcessor # noqa: F401
6 changes: 5 additions & 1 deletion pytorch_lightning/loops/batch/training_batch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,13 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio
step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])

lightning_module = self.trainer.lightning_module
training_step_fx = getattr(self.trainer.lightning_module, "training_step")

if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
if not is_param_in_hook_signature(training_step_fx, "batch_idx"):
del step_kwargs["batch_idx"]
tchaton marked this conversation as resolved.
Show resolved Hide resolved

if len(self.trainer.optimizers) > 1:
training_step_fx = getattr(lightning_module, "training_step")
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
if has_opt_idx_in_train_step:
if not lightning_module.automatic_optimization:
Expand Down
11 changes: 4 additions & 7 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,14 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
"""Performs evaluation on one single dataloader"""
void(*args, **kwargs)

dataloader_idx: int = self.current_dataloader_idx
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader = self.trainer.data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=self.current_dataloader_idx
)
dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
dataloader_iter = iter(dataloader)

dl_max_batches = self._max_batches[self.current_dataloader_idx]
dl_max_batches = self._max_batches[dataloader_idx]

dl_outputs = self.epoch_loop.run(
dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
)
dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders)

# store batch level output per dataloader
if self.should_track_batch_outputs_for_epoch_end:
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def on_run_start(
self._dl_max_batches = dl_max_batches
tchaton marked this conversation as resolved.
Show resolved Hide resolved
self._num_dataloaders = num_dataloaders

self.dataloader_iter = enumerate(dataloader_iter, self.batch_progress.current.ready)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

def advance(
self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
) -> None:
Expand All @@ -86,7 +88,7 @@ def advance(
"""
void(dl_max_batches, num_dataloaders)

batch_idx, (batch, _) = next(dataloader_iter)
batch_idx, (batch, _) = next(self.dataloader_iter)

if batch is None:
raise StopIteration
Expand Down
65 changes: 26 additions & 39 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,13 @@

from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.processors import IteratorBatchProcessor
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT

# TODO: currently, the batch processor is only a loop when tbptt is enabled.
# As we introduce more specialized batch processors, we may want to choose a
# more suitable abstraction for them.
BATCH_LOOP_TYPE = Optional[Tuple[TrainingBatchLoop, IteratorBatchProcessor]]


class TrainingEpochLoop(loops.Loop):
"""
Expand All @@ -50,7 +45,7 @@ def __init__(self, min_steps: int, max_steps: int):
self.batch_progress = Progress()
self.scheduler_progress = SchedulerProgress()

self.batch_loop: BATCH_LOOP_TYPE = None
self.batch_loop: Optional[TrainingBatchLoop] = None
self.val_loop: Optional["loops.EvaluationLoop"] = None

self._results = ResultCollection(training=True)
Expand Down Expand Up @@ -81,7 +76,7 @@ def done(self) -> bool:

def connect(
self,
batch_loop: BATCH_LOOP_TYPE = None,
batch_loop: TrainingBatchLoop = None,
tchaton marked this conversation as resolved.
Show resolved Hide resolved
val_loop: Optional["loops.EvaluationLoop"] = None,
) -> None:
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
Expand All @@ -102,14 +97,16 @@ def reset(self) -> None:
self.scheduler_progress.current.reset()
self.batch_loop.optim_progress.reset_on_epoch()

def on_run_start(self, *args: Any, **kwargs: Any) -> None:
def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
# hook
self.trainer.logger_connector.on_epoch_start()
self.trainer.call_hook("on_epoch_start")
self.trainer.call_hook("on_train_epoch_start")
self.trainer.fit_loop.epoch_progress.increment_started()

def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
self._prepare_dataloader_iter(dataloader_iter)

def advance(self, *args: Any, **kwargs: Any) -> None:
"""Runs a single training batch.

Args:
Expand All @@ -118,33 +115,18 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
Raises:
StopIteration: When the epoch is canceled by the user returning -1
"""
if isinstance(self.batch_loop, IteratorBatchProcessor):
# By contract, when taking `dataloader_iter` as an argument,
# `training_step` is responsible for reporting `is_last` in the
# result dict, which is used to determine the stop condition for
# the epoch. So as long as `advance` is invoked, it's correct to
# assume that there are more batches to be processed.
self.batch_progress.increment_ready()
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(dataloader_iter)
self.batch_progress.increment_processed()
is_last = batch_output.is_last
else:
_, (batch, is_last) = next(dataloader_iter)

if not self.trainer.data_connector.train_data_fetcher.store_on_device:
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch)

# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
self.batch_progress.increment_ready()

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, self.batch_idx)

self.batch_progress.increment_processed()
batch_idx, (batch, is_last) = next(self.dataloader_iter)

if not self.trainer.data_connector.train_data_fetcher.store_on_device:
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch)

self.batch_progress.increment_ready()

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)

self.batch_progress.increment_processed()

self.is_last_batch = is_last

Expand All @@ -162,8 +144,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)

# hook
if not isinstance(self.batch_loop, IteratorBatchProcessor):
self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
self.trainer.call_hook("on_batch_end")
self.trainer.logger_connector.on_batch_end()

Expand Down Expand Up @@ -404,3 +385,9 @@ def _save_loggers_on_train_batch_end(self) -> None:
should_flush_logs = self.trainer.logger_connector.should_flush_logs
if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
self.trainer.logger.save()

def _prepare_dataloader_iter(self, dataloader_iter: AbstractDataFetcher) -> None:
if not isinstance(dataloader_iter, DataLoaderIterDataFetcher):
dataloader_iter = enumerate(dataloader_iter, self.batch_idx + 1)
# restore iteration
self.dataloader_iter = dataloader_iter
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
from contextlib import suppress
from typing import Iterator, Optional
from typing import Optional

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
Expand Down
15 changes: 0 additions & 15 deletions pytorch_lightning/loops/processors/__init__.py

This file was deleted.

25 changes: 25 additions & 0 deletions pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature


class ConfigValidator:
Expand All @@ -34,6 +35,7 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, "val")
self.__verify_manual_optimization_support(model)
self.__check_training_step_requires_dataloader_iter(model)
elif self.trainer.state.fn == TrainerFn.VALIDATING:
self.__verify_eval_loop_configuration(model, "val")
elif self.trainer.state.fn == TrainerFn.TESTING:
Expand Down Expand Up @@ -128,3 +130,26 @@ def __verify_manual_optimization_support(self, model: "pl.LightningModule") -> N
f" Remove `Trainer(accumulate_grad_batches={self.trainer.accumulate_grad_batches})`"
" or switch to automatic optimization."
)

def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningModule"):
"""Check if the current `training_step` is requesting `dataloader_iter`."""
training_step_fx = getattr(model, "training_step")
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):

if is_overridden("on_train_batch_start", model):
raise MisconfigurationException(
"The model hook `on_train_batch_start` is not compatible with "
"taking a `dataloader_iter` argument in your `training_step`."
)

if is_overridden("on_train_batch_end", model):
raise MisconfigurationException(
"The model hook `on_train_batch_end` is not compatible with "
"taking a `dataloader_iter` argument in your `training_step`."
)

if model.truncated_bptt_steps > 0:
raise MisconfigurationException(
"The model taking a `dataloader_iter` argument in your `training_step` "
"is incompatible with `truncated_bptt_steps > 0`."
)
23 changes: 14 additions & 9 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,18 @@ def on_trainer_init(
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False

def _check_training_step_requires_dataloader_iter(self) -> bool:
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True)
return contains_dataloader_iter

def _select_data_fetcher(self) -> AbstractDataFetcher:
if self.trainer.sanity_checking:
return DataFetcher()

if self.trainer.training and self._check_training_step_requires_dataloader_iter():
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
if self.trainer.training and is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
rank_zero_warn(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return DataLoaderIterDataFetcher()

elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
# note: this is an experimental feature
if not self.trainer.training_type_plugin.on_gpu:
Expand All @@ -124,9 +121,7 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0)
profiler=self.trainer.profiler,
)
setattr(self, f"{stage}_data_fetcher", data_fetcher)
if isinstance(data_fetcher, DataLoaderIterDataFetcher):
return data_fetcher
return enumerate(data_fetcher)
return data_fetcher

def prepare_data(self) -> None:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
Expand Down Expand Up @@ -250,6 +245,16 @@ def detach_data(model: "pl.LightningModule") -> None:
if isinstance(loader, _PatchDataLoader):
loader.unpatch(model)

def teardown(self) -> None:
if self.train_data_fetcher:
self.train_data_fetcher.teardown()
if self.validate_data_fetcher:
self.validate_data_fetcher.teardown()
if self.test_data_fetcher:
self.test_data_fetcher.teardown()
if self.sanity_check_data_fetcher:
self.sanity_check_data_fetcher.teardown()


class _PatchDataLoader:
r"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,11 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT:
"""

def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
self.trainer._results.extract_batch_size(split_batch)
# when the user request `dataloader_iter`, we can't track the batch_size
# and this is left to user responsibility.
if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
self.trainer._results.extract_batch_size(split_batch)

self._batch_idx = batch_idx
self._split_idx = split_idx

Expand Down
18 changes: 2 additions & 16 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loops import IteratorBatchProcessor, TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
from pytorch_lightning.loops.fit_loop import FitLoop
Expand Down Expand Up @@ -920,28 +920,13 @@ def _load_checkpoint_weights(self):
rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
self.checkpoint_connector.restore_model_weights(self._ckpt_path)

def _maybe_switch_to_iterator_batch_processor(self, model: "pl.LightningModule") -> None:
training_step_fx = getattr(model, "training_step")
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
log.warning(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior may subject to change."
)
batch_loop = IteratorBatchProcessor(self, model)
self.fit_loop.epoch_loop.connect(batch_loop)
# FIXME: Move this logic to data_connector after removing `IteratorBatchProcessor`
self.data_connector.data_fetcher = DataLoaderIterDataFetcher()

def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
# clean hparams
if hasattr(model, "hparams"):
parsing.clean_namespace(model.hparams)

self.config_validator.verify_loop_configurations(model)

if self.training:
self._maybe_switch_to_iterator_batch_processor(model)

# attach model log function to callback
self.callback_connector.attach_model_logging_functions(model)

Expand Down Expand Up @@ -1077,6 +1062,7 @@ def _post_dispatch(self):
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
# which need to happen before.
self.accelerator.teardown()
self.data_connector.teardown()
self._active_loop.teardown()
self.logger_connector.teardown()

Expand Down
Loading