diff --git a/CHANGELOG.md b/CHANGELOG.md index 313a790ce5468..2fdcc224d8eb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `on_train_epoch_end` from `Accelerator` ([#9035](https://github.com/PyTorchLightning/pytorch-lightning/pull/9035)) +- Removed `IteratorBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052)) + + ### Fixed - Fixed save/load/resume from checkpoint for DeepSpeed Plugin ( diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 01f9b4fc025fe..096333388c3b1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -456,6 +456,15 @@ def log( f" of {list(self._metric_attributes.values())}" ) + if ( + self.trainer.training + and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) + and batch_size is None + ): + raise MisconfigurationException( + "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." 
+ ) + results.log( self._current_fx_name, name, diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index c9775ed44155e..b7eb47167d26f 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -17,4 +17,3 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401 from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401 from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 -from pytorch_lightning.loops.processors import IteratorBatchProcessor # noqa: F401 diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 163d7681f29a8..29517ad306eba 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -547,12 +547,15 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio the keyword arguments for the training step """ # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)]) + step_kwargs = OrderedDict([("batch", batch)]) lightning_module = self.trainer.lightning_module + training_step_fx = getattr(lightning_module, "training_step") + + if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2): + step_kwargs["batch_idx"] = batch_idx if len(self.trainer.optimizers) > 1: - training_step_fx = getattr(lightning_module, "training_step") has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") if has_opt_idx_in_train_step: if not lightning_module.automatic_optimization: diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index d15236dc7694c..7f06f5cd4ff63 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ 
b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -98,17 +98,14 @@ def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader""" void(*args, **kwargs) + dataloader_idx: int = self.current_dataloader_idx dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - dataloader = self.trainer.data_connector.get_profiled_dataloader( - dataloader, dataloader_idx=self.current_dataloader_idx - ) + dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx) dataloader_iter = iter(dataloader) - dl_max_batches = self._max_batches[self.current_dataloader_idx] + dl_max_batches = self._max_batches[dataloader_idx] - dl_outputs = self.epoch_loop.run( - dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders - ) + dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders) # store batch level output per dataloader if self.should_track_batch_outputs_for_epoch_end: diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index eb3f9dad58bcf..e4770084c84cd 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -19,6 +19,7 @@ from torch import Tensor from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -37,6 +38,7 @@ def __init__(self) -> None: self._num_dataloaders: Optional[int] = None self.outputs: List[STEP_OUTPUT] = [] self.batch_progress = Progress() + self.dataloader_iter: Optional[Iterator] = None @property def done(self) -> bool: @@ -66,10 +68,12 @@ def on_run_start( dl_max_batches: maximum number of 
batches the dataloader can produce num_dataloaders: the total number of dataloaders """ - void(dataloader_iter, dataloader_idx) + void(dataloader_idx) self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders + self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready) + def advance( self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int ) -> None: @@ -84,9 +88,9 @@ def advance( Raises: StopIteration: If the current batch is None """ - void(dl_max_batches, num_dataloaders) + void(dataloader_iter, dl_max_batches, num_dataloaders) - batch_idx, (batch, _) = next(dataloader_iter) + batch_idx, (batch, _) = next(self.dataloader_iter) if batch is None: raise StopIteration diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 741a05cd5701e..43d51fe0027c6 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -11,24 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Union import torch from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop -from pytorch_lightning.loops.processors import IteratorBatchProcessor +from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import Progress, SchedulerProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import STEP_OUTPUT -# TODO: currently, the batch processor is only a loop when tbptt is enabled. -# As we introduce more specialized batch processors, we may want to choose a -# more suitable abstraction for them. -BATCH_LOOP_TYPE = Optional[Tuple[TrainingBatchLoop, IteratorBatchProcessor]] - class TrainingEpochLoop(loops.Loop): """ @@ -50,7 +45,7 @@ def __init__(self, min_steps: int, max_steps: int): self.batch_progress = Progress() self.scheduler_progress = SchedulerProgress() - self.batch_loop: BATCH_LOOP_TYPE = None + self.batch_loop: Optional[TrainingBatchLoop] = None self.val_loop: Optional["loops.EvaluationLoop"] = None self._results = ResultCollection(training=True) @@ -81,7 +76,7 @@ def done(self) -> bool: def connect( self, - batch_loop: BATCH_LOOP_TYPE = None, + batch_loop: TrainingBatchLoop = None, val_loop: Optional["loops.EvaluationLoop"] = None, ) -> None: """Optionally connect a custom batch or validation loop to this training epoch loop.""" @@ -102,14 +97,16 @@ def reset(self) -> None: self.scheduler_progress.current.reset() self.batch_loop.optim_progress.reset_on_epoch() - def on_run_start(self, *args: Any, **kwargs: Any) -> None: + def on_run_start(self, dataloader_iter: Iterator, 
**kwargs: Any) -> None: # hook self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") self.trainer.fit_loop.epoch_progress.increment_started() - def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: + self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_idx + 1) + + def advance(self, *args: Any, **kwargs: Any) -> None: """Runs a single training batch. Args: @@ -118,33 +115,18 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: Raises: StopIteration: When the epoch is canceled by the user returning -1 """ - if isinstance(self.batch_loop, IteratorBatchProcessor): - # By contract, when taking `dataloader_iter` as an argument, - # `training_step` is responsible for reporting `is_last` in the - # result dict, which is used to determine the stop condition for - # the epoch. So as long as `advance` is invoked, it's correct to - # assume that there are more batches to be processed. 
- self.batch_progress.increment_ready() - with self.trainer.profiler.profile("run_training_batch"): - batch_output = self.batch_loop.run(dataloader_iter) - self.batch_progress.increment_processed() - is_last = batch_output.is_last - else: - _, (batch, is_last) = next(dataloader_iter) - - if not self.trainer.data_connector.train_data_fetcher.store_on_device: - with self.trainer.profiler.profile("training_batch_to_device"): - batch = self.trainer.accelerator.batch_to_device(batch) - - # ------------------------------------ - # TRAINING_STEP + TRAINING_STEP_END - # ------------------------------------ - self.batch_progress.increment_ready() - - with self.trainer.profiler.profile("run_training_batch"): - batch_output = self.batch_loop.run(batch, self.batch_idx) - - self.batch_progress.increment_processed() + batch_idx, (batch, is_last) = next(self.dataloader_iter) + + if not self.trainer.data_connector.train_data_fetcher.store_on_device: + with self.trainer.profiler.profile("training_batch_to_device"): + batch = self.trainer.accelerator.batch_to_device(batch) + + self.batch_progress.increment_ready() + + with self.trainer.profiler.profile("run_training_batch"): + batch_output = self.batch_loop.run(batch, batch_idx) + + self.batch_progress.increment_processed() self.is_last_batch = is_last @@ -162,8 +144,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) # hook - if not isinstance(self.batch_loop, IteratorBatchProcessor): - self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0) + self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0) self.trainer.call_hook("on_batch_end") self.trainer.logger_connector.on_batch_end() diff --git a/pytorch_lightning/loops/processors/__init__.py b/pytorch_lightning/loops/processors/__init__.py deleted file mode 100644 index 
9fcbe9e82dca8..0000000000000 --- a/pytorch_lightning/loops/processors/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pytorch_lightning.loops.processors.iterator_batch_processor import IteratorBatchProcessor # noqa: F401 diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py deleted file mode 100644 index c1981173215ae..0000000000000 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging -from collections import OrderedDict -from typing import Any, Dict, Iterator, List, Optional, Tuple - -import torch - -import pytorch_lightning as pl -from pytorch_lightning.loops.utilities import ( - _check_training_step_output, - _process_training_step_output, - check_finite_loss, -) -from pytorch_lightning.trainer.progress import OptimizationProgress -from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import AttributeDict -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature - -log = logging.getLogger(__name__) - - -class IteratorBatchProcessor: - """ - The processor for performing a training iteration when ``training_step`` needs access to the - dataloader. It is selected when the signature of ``training_step`` contains ``dataloader_iter``: - - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - - The ``training_step`` is allowed to fetch multiple batches during one training iteration. The - framework provides minimum amount of automation with regards to model optimization. The - flexibility allows for ease of experimentation with inter-batch parallelism techniques. - - This processor doesn't support ``automatic_optimization`` and ``tbptt``. An error will be thrown - if the ``LightningModule`` or the ``Trainer`` is configured to use these features. - - The ``training_step`` is responsible for reporting whether it has reached the last batch by - including an ``is_last`` field in the result dict. Failing to do so will result in an error. - - The ``training_step`` should only optimize the model with one batch for the sake of API and - reporting consistency (TODO: consider removing this limitation). 
- - Args: - trainer: a reference to the trainer - model: a reference to the lightning module (for config validation purposes only) - """ - - def __init__(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: - if is_overridden("on_train_batch_start", model): - raise MisconfigurationException( - "The model hook `on_train_batch_start` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if is_overridden("on_train_batch_end", model): - raise MisconfigurationException( - "The model hook `on_train_batch_end` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if is_overridden("tbptt_split_batch", model): - raise MisconfigurationException( - "The model hook `tbptt_split_batch` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if trainer.accumulate_grad_batches != 1: - raise MisconfigurationException( - "`accumulate_grad_batches` can only be 1 when your " - "`training_step` takes `dataloader_iter` as an argument." - ) - - self.trainer = trainer - - # The following field is not used by the processor since it doesn't support automatic - # optimization and tbptt. Initializing them regardless since they are currently expected by - # `FitLoop` or `TrainingEpochLoop`. - # TODO: come up with an abstraction for "batch processors" so they can be better decoupled - # with parent loops. - self.accumulated_loss: Optional[torch.Tensor] = None - self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=1) - self.optim_progress = OptimizationProgress() - self.split_idx: int = 0 - self._skip_backward = False - - def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int: - """ - Returns the number of active optimizers. 
- """ - return len(self.trainer.optimizers) - - def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, torch.optim.Optimizer]]: - """ - Returns the currently active optimizers. - - Returns: - A list of tuples (opt_idx, optimizer) of currently active optimizers. - """ - return list(enumerate(self.trainer.optimizers)) - - def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: - """ - Args: - dataloader_iter: the iterator over the dataloader producing the new batch - """ - batch_idx, (dataloader_iter, is_last) = next(dataloader_iter) - - self.trainer.logger_connector.on_batch_start() - response = self.trainer.call_hook("on_batch_start") - if response == -1: - return AttributeDict(signal=-1) - - self.trainer.fit_loop.epoch_loop.batch_progress.increment_started() - - # give the PL module a result for logging - model = self.trainer.lightning_module - # manually capture logged metrics - model._current_fx_name = "training_step" - step_kwargs = self._build_kwargs(dataloader_iter, batch_idx) - - with self.trainer.profiler.profile("model_forward"): - with self.trainer.profiler.profile("training_step"): - training_step_output = self.trainer.accelerator.training_step(step_kwargs) - self.trainer.accelerator.post_training_step() - - training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - _check_training_step_output(self.trainer.lightning_module, training_step_output) - - training_step_output, _ = _process_training_step_output(self.trainer, training_step_output) - - if self.trainer.terminate_on_nan: - check_finite_loss(self.trainer.lightning_module, training_step_output.minimize) - - batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - if training_step_output: - batch_outputs[0].append(training_step_output) - - return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last) - - def teardown(self) -> None: - """ - No-op. Only defined to comply with FitLoop's expectation. 
- """ - pass - - # FIXME: To be deleted in next PR. - def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Any]: - """Builds the keyword arguments for training_step - - Args: - dataloader_iter: The dataloader to pass - batch_idx: the index of the current batch - - Returns: - An ordered dict with the keyword arguments for the training step - """ - # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter)]) - - training_step_fx = getattr(self.trainer.lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "batch_idx"): - step_kwargs["batch_idx"] = batch_idx - - return step_kwargs diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 1f0636d8b6cda..89ba5cd07d459 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Mapping, Optional, Tuple +from typing import Iterator, Mapping, Optional, Tuple import torch @@ -20,6 +20,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -102,3 +103,11 @@ def _process_training_step_output( if trainer.move_metrics_to_cpu: results.cpu() return results, hiddens + + +def _prepare_dataloader_iter(dataloader_iter: Iterator, batch_idx: int) -> Iterator: + """Attach the dataloader""" + if not isinstance(dataloader_iter, DataLoaderIterDataFetcher): + dataloader_iter = enumerate(dataloader_iter, batch_idx) + # restore iteration + return dataloader_iter diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 07548f9c49074..d9c341c5dfaeb 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -16,6 +16,7 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature class ConfigValidator: @@ -34,6 +35,7 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: self.__verify_train_loop_configuration(model) self.__verify_eval_loop_configuration(model, "val") self.__verify_manual_optimization_support(model) + self.__check_training_step_requires_dataloader_iter(model) elif self.trainer.state.fn == TrainerFn.VALIDATING: self.__verify_eval_loop_configuration(model, 
"val") elif self.trainer.state.fn == TrainerFn.TESTING: @@ -128,3 +130,26 @@ def __verify_manual_optimization_support(self, model: "pl.LightningModule") -> N f" Remove `Trainer(accumulate_grad_batches={self.trainer.accumulate_grad_batches})`" " or switch to automatic optimization." ) + + def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningModule"): + """Check if the current `training_step` is requesting `dataloader_iter`.""" + training_step_fx = getattr(model, "training_step") + if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): + + if is_overridden("on_train_batch_start", model): + raise MisconfigurationException( + "The model hook `on_train_batch_start` is not compatible with " + "taking a `dataloader_iter` argument in your `training_step`." + ) + + if is_overridden("on_train_batch_end", model): + raise MisconfigurationException( + "The model hook `on_train_batch_end` is not compatible with " + "taking a `dataloader_iter` argument in your `training_step`." + ) + + if model.truncated_bptt_steps > 0: + raise MisconfigurationException( + "The model taking a `dataloader_iter` argument in your `training_step` " + "is incompatible with `truncated_bptt_steps > 0`." 
+ ) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 8d337b972dce2..cdf4ce1c471bf 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -91,21 +91,18 @@ def on_trainer_init( self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs self.trainer._is_data_prepared = False - def _check_training_step_requires_dataloader_iter(self) -> bool: - training_step_fx = getattr(self.trainer.lightning_module, "training_step") - contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True) - return contains_dataloader_iter - def _select_data_fetcher(self) -> AbstractDataFetcher: if self.trainer.sanity_checking: return DataFetcher() - if self.trainer.training and self._check_training_step_requires_dataloader_iter(): + training_step_fx = getattr(self.trainer.lightning_module, "training_step") + if self.trainer.training and is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): rank_zero_warn( "Found `dataloader_iter` argument in the `training_step`. Note that the support for " "this signature is experimental and the behavior is subject to change." 
+ ) return DataLoaderIterDataFetcher() + elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": # note: this is an experimental feature if not self.trainer.training_type_plugin.on_gpu: @@ -124,9 +121,7 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) profiler=self.trainer.profiler, ) setattr(self, f"{stage}_data_fetcher", data_fetcher) - if isinstance(data_fetcher, DataLoaderIterDataFetcher): - return data_fetcher - return enumerate(data_fetcher) + return data_fetcher def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 @@ -250,6 +245,16 @@ def detach_data(model: "pl.LightningModule") -> None: if isinstance(loader, _PatchDataLoader): loader.unpatch(model) + def teardown(self) -> None: + if self.train_data_fetcher: + self.train_data_fetcher.teardown() + if self.validate_data_fetcher: + self.validate_data_fetcher.teardown() + if self.test_data_fetcher: + self.test_data_fetcher.teardown() + if self.sanity_check_data_fetcher: + self.sanity_check_data_fetcher.teardown() + class _PatchDataLoader: r""" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 672695885ebf9..a965699510689 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -199,7 +199,11 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: """ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: - self.trainer._results.extract_batch_size(split_batch) + # when the user requests `dataloader_iter`, we can't track the batch_size + # and this is left to the user's responsibility. 
+ if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher): + self.trainer._results.extract_batch_size(split_batch) + self._batch_idx = batch_idx self._split_idx = split_idx diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ac66b1083f2ed..19ccf3935a168 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -28,7 +28,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.loops import IteratorBatchProcessor, TrainingBatchLoop, TrainingEpochLoop +from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop @@ -77,12 +77,10 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS log = logging.getLogger(__name__) @@ -920,18 +918,6 @@ def _load_checkpoint_weights(self): rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}") 
self.checkpoint_connector.restore_model_weights(self._ckpt_path) - def _maybe_switch_to_iterator_batch_processor(self, model: "pl.LightningModule") -> None: - training_step_fx = getattr(model, "training_step") - if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): - log.warning( - "Found `dataloader_iter` argument in the `training_step`. Note that the support for " - "this signature is experimental and the behavior may subject to change." - ) - batch_loop = IteratorBatchProcessor(self, model) - self.fit_loop.epoch_loop.connect(batch_loop) - # FIXME: Move this logic to data_connector after removing `IteratorBatchProcessor` - self.data_connector.data_fetcher = DataLoaderIterDataFetcher() - def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): @@ -939,9 +925,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.config_validator.verify_loop_configurations(model) - if self.training: - self._maybe_switch_to_iterator_batch_processor(model) - # attach model log function to callback self.callback_connector.attach_model_logging_functions(model) @@ -1077,6 +1060,7 @@ def _post_dispatch(self): # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns # which need to happen before. self.accelerator.teardown() + self.data_connector.teardown() self._active_loop.teardown() self.logger_connector.teardown() diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 72f54a891cde3..d37cd3a9c1e6f 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -390,7 +390,6 @@ def __iter__(self) -> "StepFuncDataLoaderIter": def __next__(self) -> Any: try: data = next(self.iterator) - # FIXME: Link this to `batch_idx`. 
self.data_fetcher.fetched += 1 return data except StopIteration: diff --git a/pytorch_lightning/utilities/signature_utils.py b/pytorch_lightning/utilities/signature_utils.py index 5c7e468d84738..05045e98d3af6 100644 --- a/pytorch_lightning/utilities/signature_utils.py +++ b/pytorch_lightning/utilities/signature_utils.py @@ -12,15 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import Callable +from typing import Callable, Optional -def is_param_in_hook_signature(hook_fx: Callable, param: str, explicit: bool = False) -> bool: +def is_param_in_hook_signature( + hook_fx: Callable, param: str, explicit: bool = False, min_args: Optional[int] = None +) -> bool: """ Args: hook_fx: the hook callable param: the name of the parameter to check explicit: whether the parameter has to be explicitly declared + min_args: if given, also return True when the signature has at least `min_args` parameters """ hook_params = list(inspect.signature(hook_fx).parameters) - return param in hook_params or (not explicit and "args" in hook_params) + return ( + param in hook_params + or (not explicit and "args" in hook_params) + or (isinstance(min_args, int) and len(hook_params) >= min_args) + ) diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index f76e76b2f9dd9..6a8192ef0149e 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -17,7 +17,7 @@ import torch import pytorch_lightning as pl -from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from tests.helpers import BoringModel @@ -27,8 +27,6 @@ def test_finetuning_with_resume_from_checkpoint(tmpdir): This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test """ - seed_everything(4) - checkpoint_callback = 
ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1) class ExtendedBoringModel(BoringModel): @@ -75,9 +73,6 @@ def validation_step(self, batch, batch_idx): results.append(deepcopy(trainer.callback_metrics)) best_model_paths.append(trainer.checkpoint_callback.best_model_path) - for idx in range(len(results) - 1): - assert results[idx]["val_loss"] > results[idx + 1]["val_loss"] - for idx, best_model_path in enumerate(best_model_paths): if idx == 0: assert best_model_path.endswith(f"epoch=0{idx}.ckpt") diff --git a/tests/loops/test_inter_batch_parallelism.py b/tests/loops/test_inter_batch_parallelism.py deleted file mode 100644 index 00bc7049b0a29..0000000000000 --- a/tests/loops/test_inter_batch_parallelism.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import time -from statistics import mean -from typing import Iterator - -import torch -from torch.utils.data import DataLoader, IterableDataset - -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.utilities.types import STEP_OUTPUT -from tests.helpers.runif import RunIf - - -def count_cycles_per_ms() -> float: - """ - Measure and return approximate number of cycles per millisecond for torch.cuda._sleep - - Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py - """ - - def measure() -> float: - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - torch.cuda._sleep(1000000) - end.record() - end.synchronize() - cycles_per_ms = 1000000 / start.elapsed_time(end) - return cycles_per_ms - - # Get 10 values and remove the 2 max and 2 min and return the avg. - # This is to avoid system disturbance that skew the results, e.g. - # the very first cuda call likely does a bunch of init, which takes - # much longer than subsequent calls. - # - # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs - # and seems to return stable values. Therefore, we enable caching - # using lru_cache decorator above. - num = 10 - vals = [] - for _ in range(num): - vals.append(measure()) - vals = sorted(vals) - return mean(vals[2 : num - 2]) - - -_CYCLES_PER_MS = int(count_cycles_per_ms()) if torch.cuda.is_available() else 0 -_BATCH_SIZE = 128 -_EMB_SZ = 100 -_EMB_DIM = 64 - - -class RandomSparseDataset(IterableDataset): - def __init__(self, emb_dim: int, batch_size: int, count: int) -> None: - self.emb_dim = emb_dim - self.batch_size = batch_size - self.count = count - - def __iter__(self): - for _ in range(self.count): - yield torch.randint(self.emb_dim, [self.batch_size]) - - -class ToyDLRMModel(LightningModule): - """ - A toy model for mimicking the communication overhead of sharded embedding - modules in DLRM models. 
- - DLRM models can be trained in a DDP-like fashion, where each trainer - receives different batches (embedding indices in this example). Since the - embeddings are sharded across trainers, the lookup process involves (1) - routing the indices to the trainer that possesses the corresponding - embeddings (2) performing local lookup (3) routing the embedding lookup - result back. - - The toy model doesn't actually performs index/result routing. It simply - uses torch.cuda._sleep() to mimic the cost of the communication op (i.e. - a2a). - """ - - def __init__(self): - super().__init__() - self.automatic_optimization = False - self.local_embedding = torch.nn.Embedding(_EMB_SZ, _EMB_DIM) - - def _route_indices(self, batch: torch.Tensor, non_blocking=False): - """ - This can be parallelized across different batches since it's model - weight independent. - - Why not run this in dataloader/datamodule? - - The routing logic depends on how model is sharded - - Putting this in data preprocessor changes the semantic of the model - """ - torch.cuda._sleep(_CYCLES_PER_MS * 1_000) - if not non_blocking: - torch.cuda.synchronize() - return batch - - def _route_result(self, result: torch.Tensor, non_blocking=False): - torch.cuda._sleep(_CYCLES_PER_MS * 1_000) - if not non_blocking: - torch.cuda.synchronize() - return result - - def forward(self, indices: torch.Tensor): - local_indices = self._route_indices(indices) - result = self.local_embedding(local_indices) - return self._route_result(result) - - def training_step(self, batch: torch.Tensor, batch_idx: int) -> STEP_OUTPUT: - return self.forward(batch) - - def configure_optimizers(self): - return torch.optim.SGD(self.local_embedding.parameters(), lr=0.1) - - def train_dataloader(self): - return DataLoader(RandomSparseDataset(_EMB_DIM, _BATCH_SIZE, 5)) - - -class AsyncToyDLRMModel(ToyDLRMModel): - def __init__(self): - super().__init__() - self.comm_stream = torch.cuda.Stream() - self.batch_i = None - self.batch_i_ready = 
torch.cuda.Event() - - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - if self.batch_i is None: - self.batch_i = next(dataloader_iter) - with torch.cuda.stream(self.comm_stream): - self._route_indices(self.batch_i, non_blocking=True) - self.batch_i_ready.record() - - # Invariant: the routing for batch[i] has been kicked off - is_last = False - batch_ip1 = None - batch_ip1_ready = torch.cuda.Event() - try: - batch_ip1 = next(dataloader_iter) - with torch.cuda.stream(self.comm_stream): - self._route_indices(batch_ip1, non_blocking=True) - batch_ip1_ready.record() - except StopIteration: - is_last = True - - self.batch_i_ready.wait() - - result = self.local_embedding(self.batch_i) - self._route_result(result) - - self.batch_i = batch_ip1 - self.batch_i_ready = batch_ip1_ready - - return {"is_last": is_last} - - -@RunIf(min_gpus=1) -def test_inter_batch_parallelism(tmpdir): - """ - Verify the speedup of a simple inter-batch parallelization use case enabled - by exposing `dataloader_iter` to `training_step`. - """ - begin_time = time.time() - m = AsyncToyDLRMModel() - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - trainer.fit(m) - async_duration = time.time() - begin_time - - begin_time = time.time() - m = ToyDLRMModel() - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - trainer.fit(m) - sync_duration = time.time() - begin_time - - # We expect 2x speedup. However, we only assert that the async - # training_step is faster in order to avoid flaky tests - assert async_duration < sync_duration, "Expect `AsyncToyDLRMModel` to train faster than `ToyDLRMModel`." diff --git a/tests/loops/test_iterator_batch_processor.py b/tests/loops/test_iterator_batch_processor.py deleted file mode 100644 index 2cd6a172f6941..0000000000000 --- a/tests/loops/test_iterator_batch_processor.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright The PyTorch Lightning team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Iterator - -import pytest -from torch.utils.data import DataLoader - -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import STEP_OUTPUT -from tests.helpers import BoringModel, RandomDataset - -_BATCH_SIZE = 32 -_DATASET_LEN = 64 - - -class DummyWaitable: - def __init__(self, val: Any) -> None: - self.val = val - - def wait(self) -> Any: - return self.val - - -class AsyncBoringModel(BoringModel): - def __init__(self) -> None: - super().__init__() - self.automatic_optimization = False - self.batch_i_handle = None - self.num_batches_processed = 0 - - def _async_op(self, batch: Any) -> DummyWaitable: - return DummyWaitable(val=batch) - - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - if self.batch_i_handle is None: - batch_i_raw = next(dataloader_iter) - self.batch_i_handle = self._async_op(batch_i_raw) - - # Invariant: _async_op for batch[i] has been initiated - batch_ip1_handle = None - is_last = False - try: - batch_ip1_raw = next(dataloader_iter) - batch_ip1_handle = self._async_op(batch_ip1_raw) - except StopIteration: - is_last = True - - batch_i = self.batch_i_handle.wait() - - pred = self.layer(batch_i) - loss = self.loss(batch_i, pred) - loss.backward() - self.optimizers().step() - self.optimizers().zero_grad() - - self.batch_i_handle = batch_ip1_handle - 
self.num_batches_processed += 1 - - return {"loss": loss, "is_last": is_last} - - def train_dataloader(self): - return DataLoader(RandomDataset(_BATCH_SIZE, _DATASET_LEN)) - - -def test_training_step_with_dataloader_access(tmpdir) -> None: - """ - A baseline functional test for `training_step` with dataloader access. - """ - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = AsyncBoringModel() - trainer.fit(m) - assert m.num_batches_processed == _DATASET_LEN, f"Expect all {_DATASET_LEN} batches to be processed." - - -def test_stop_iteration(tmpdir) -> None: - """ - Verify that when `StopIteration` is raised within `training_step`, `fit()` - terminiates as expected. - """ - EXPECT_NUM_BATCHES_PROCESSED = 2 - - class TestModel(AsyncBoringModel): - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - output = super().training_step(dataloader_iter) - if self.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED: - raise StopIteration() - return output - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = TestModel() - trainer.fit(m) - assert ( - m.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED - ), "Expect {EXPECT_NUM_BATCHES_PROCESSED} batches to be processed." - - -def test_on_train_batch_start_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `on_train_batch_start` is overridden on the `LightningModule`. - """ - - class InvalidModel(AsyncBoringModel): - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - pass - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_on_train_batch_end_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `on_train_batch_end` is overridden on the `LightningModule`. 
- """ - - class InvalidModel(AsyncBoringModel): - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - pass - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_tbptt_split_batch_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `tbptt_split_batch` is overridden on the `LightningModule`. - """ - - class InvalidModel(AsyncBoringModel): - def tbptt_split_batch(self, batch, split_size): - pass - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_accumulate_grad_batches(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `accumulate_grad_batches` is not set to 1. - """ - trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir) - m = AsyncBoringModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_is_last_not_set(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when `training_step` - doesn't include "is_last" in the result dict. 
- """ - - class InvalidModel(AsyncBoringModel): - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - output = super().training_step(dataloader_iter) - del output["is_last"] - return output - - trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 670e8b4842a89..80896f6fa450c 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -419,7 +419,7 @@ class TestModel(ManualOptModel): called = False - def on_after_backward(self): + def on_before_optimizer_step(self, *args): self.called = True norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2) if not (torch.isinf(norm) or torch.isnan(norm)): diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index b351165e03fd8..bf54bbae83568 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os from time import time -from typing import Any +from typing import Any, Iterator from unittest import mock import pytest @@ -25,7 +25,8 @@ from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher -from tests.helpers.boring_model import BoringModel +from pytorch_lightning.utilities.types import STEP_OUTPUT +from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -125,7 +126,8 @@ def measure() -> float: return sum(stats) / len(stats) -BATCH_SIZE = 128 +BATCH_SIZE = 32 +DATASET_LEN = 64 EMB_SZ = 100 EMB_DIM = 64 @@ -176,6 +178,7 @@ def test_dataloader(self): def test_trainer_num_prefetch_batches(tmpdir): model = RecommenderModel() + trainer_kwargs = dict( default_root_dir=tmpdir, max_epochs=1, @@ -190,8 +193,8 @@ def test_trainer_num_prefetch_batches(tmpdir): trainer = Trainer(**trainer_kwargs) trainer.fit(model) t1 = time() - global_step = trainer.global_step assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher) + global_step = trainer.global_step torch.cuda.synchronize() @@ -199,9 +202,9 @@ def test_trainer_num_prefetch_batches(tmpdir): trainer = Trainer(**trainer_kwargs) trainer.fit(model) t3 = time() + assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) assert global_step == trainer.global_step == 4 - assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) ratio = (t3 - t2) / (t1 - t0) assert ratio > 1.1, ratio @@ -218,7 +221,7 @@ def __init__(self, *args, automatic_optimization: bool = False, **kwargs): def training_step(self, dataloader_iter, batch_idx): assert self.count == batch_idx - assert isinstance(self.trainer.data_connector.data_fetcher, DataLoaderIterDataFetcher) + assert isinstance(self.trainer.data_connector.train_data_fetcher, 
DataLoaderIterDataFetcher) # fetch 2 batches self.batches.append(next(dataloader_iter)) self.batches.append(next(dataloader_iter)) @@ -227,7 +230,10 @@ def training_step(self, dataloader_iter, batch_idx): assert isinstance(batch, torch.Tensor) or batch is None self.count += 2 if self.automatic_optimization: - return super().training_step(batch, 0) + loss = super().training_step(batch, 0) + with pytest.raises(MisconfigurationException, match="dataloader_iter"): + self.log("train_loss", loss["loss"]) + self.log("train_loss", loss["loss"], batch_size=1) else: opt = self.optimizers() output = self(batch) @@ -236,10 +242,152 @@ def training_step(self, dataloader_iter, batch_idx): loss.backward() opt.step() - training_epoch_end = None + def training_epoch_end(self, *_): + assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33 + assert self.trainer.data_connector.train_data_fetcher.fetched == 64 + assert self.count == 64 model = TestModel(automatic_optimization=automatic_optimization) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - trainer.data_connector.data_fetcher = DataLoaderIterDataFetcher() trainer.fit(model) - assert model.count == 64 + + +class DummyWaitable: + def __init__(self, val: Any) -> None: + self.val = val + + def wait(self) -> Any: + return self.val + + +class AsyncBoringModel(BoringModel): + def __init__(self) -> None: + super().__init__() + self.automatic_optimization = False + self.batch_i_handle = None + self.num_batches_processed = 0 + + def _async_op(self, batch: Any) -> DummyWaitable: + return DummyWaitable(val=batch) + + def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: + if self.batch_i_handle is None: + batch_i_raw = next(dataloader_iter) + self.batch_i_handle = self._async_op(batch_i_raw) + + # Invariant: _async_op for batch[i] has been initiated + batch_ip1_handle = None + is_last = False + try: + batch_ip1_raw = next(dataloader_iter) + batch_ip1_handle = self._async_op(batch_ip1_raw) + except 
StopIteration: + is_last = True + + batch_i = self.batch_i_handle.wait() + + pred = self.layer(batch_i) + loss = self.loss(batch_i, pred) + loss.backward() + self.optimizers().step() + self.optimizers().zero_grad() + + self.batch_i_handle = batch_ip1_handle + self.num_batches_processed += 1 + + return {"loss": loss, "is_last": is_last} + + def train_dataloader(self): + return DataLoader(RandomDataset(BATCH_SIZE, DATASET_LEN)) + + +def test_training_step_with_dataloader_access(tmpdir) -> None: + """ + A baseline functional test for `training_step` with dataloader access. + """ + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) + m = AsyncBoringModel() + trainer.fit(m) + assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed." + + +@pytest.mark.parametrize("trigger_stop_iteration", [False, True]) +def test_stop_iteration(trigger_stop_iteration, tmpdir): + """ + Verify that StopIteration properly terminates the training when this is triggered + from the current `dataloader_iter` + """ + EXPECT_NUM_BATCHES_PROCESSED = 2 + + class TestModel(AsyncBoringModel): + def __init__(self, trigger_stop_iteration) -> None: + super().__init__() + self.trigger_stop_iteration = trigger_stop_iteration + + def training_step(self, dataloader_iter: Iterator, *args) -> STEP_OUTPUT: + output = super().training_step(dataloader_iter) + if self.trigger_stop_iteration and args[0] == EXPECT_NUM_BATCHES_PROCESSED: + raise StopIteration + return output + + def train_dataloader(self): + if self.trigger_stop_iteration: + return DataLoader(RandomDataset(BATCH_SIZE, 2 * EXPECT_NUM_BATCHES_PROCESSED)) + return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED)) + + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) + m = TestModel(trigger_stop_iteration) + trainer.fit(m) + expected = EXPECT_NUM_BATCHES_PROCESSED + if trigger_stop_iteration: + expected *= 2 + assert m.num_batches_processed == expected + + +def
test_on_train_batch_start_overridden(tmpdir) -> None: + """ + Verify that a `MisconfigurationException` is raised when + `on_train_batch_start` is overridden on the `LightningModule`. + """ + + class InvalidModel(AsyncBoringModel): + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + pass + + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) + m = InvalidModel() + with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_start` is not compatible with"): + trainer.fit(m) + + +def test_on_train_batch_end_overridden(tmpdir) -> None: + """ + Verify that a `MisconfigurationException` is raised when + `on_train_batch_end` is overridden on the `LightningModule`. + """ + + class InvalidModel(AsyncBoringModel): + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + pass + + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) + m = InvalidModel() + with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_end` is not compatible with"): + trainer.fit(m) + + +def test_tbptt_split_batch_overridden(tmpdir) -> None: + """ + Verify that a `MisconfigurationException` is raised when + `truncated_bptt_steps` is set to a positive value on the `LightningModule`. + """ + + class InvalidModel(AsyncBoringModel): + def __init__(self) -> None: + super().__init__() + self.truncated_bptt_steps = 2 + + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) + m = InvalidModel() + with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."): + trainer.fit(m)