Add a flavor of training_step that takes dataloader_iter as an argument #8807

Merged: 16 commits, Aug 16, 2021
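For context, a minimal sketch of what the new flavor could look like from the user's side. This is an illustration only: the module, its layers, and the exact unpacking of the iterator are assumptions; the diff below only establishes that `training_step` receives `dataloader_iter` and, by contract, reports `is_last` in its result dict.

```python
import torch
from torch import nn
import pytorch_lightning as pl


class IteratorModel(pl.LightningModule):
    """Illustrative module using the `dataloader_iter` flavor of `training_step`."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, dataloader_iter):
        # the step fetches its own data from the prefetching iterator; this
        # unpacking mirrors what `TrainingEpochLoop.advance` does and is an
        # assumption about what the hook receives
        _, (batch, is_last) = next(dataloader_iter)
        x, y = batch
        loss = nn.functional.cross_entropy(self.layer(x), y)
        # by contract, the step reports whether this was the last batch
        return {"loss": loss, "is_last": is_last}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```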
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a flavor of `training_step` that takes `dataloader_iter` as an argument ([#8807](https://github.com/PyTorchLightning/pytorch-lightning/pull/8807))

- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


1 change: 1 addition & 0 deletions pytorch_lightning/loops/__init__.py
@@ -17,3 +17,4 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
from pytorch_lightning.loops.processors import IteratorBatchProcessor # noqa: F401
89 changes: 6 additions & 83 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -16,7 +16,7 @@
from contextlib import contextmanager
from copy import copy
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import numpy as np
import torch
@@ -26,14 +26,12 @@

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import check_finite, check_training_step_output, process_training_step_output
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -253,32 +251,7 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
self._check_finite(opt_closure_result.loss)

def _check_training_step_output(self, training_step_output: STEP_OUTPUT) -> None:
"""Sanity checks that training produced a valid output and optimizer step has already been called in manual
optimization.

Args:
training_step_output: the output of the training step (before wrapping in an AttributeDict)

"""
if isinstance(training_step_output, Tensor) and not self.trainer.lightning_module.automatic_optimization:
if training_step_output.grad_fn is None:
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
elif self.trainer.lightning_module.automatic_optimization:
if not any(
(
isinstance(training_step_output, Tensor),
(isinstance(training_step_output, Mapping) and "loss" in training_step_output),
training_step_output is None,
)
):
raise MisconfigurationException(
"In automatic optimization, `training_step` must either return a Tensor, "
"a dict with key 'loss' or None (where the step will be skipped)."
)
check_finite(self.trainer.lightning_module, opt_closure_result.loss)

def _training_step(
self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor
@@ -308,9 +281,9 @@ def _training_step(

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

self._check_training_step_output(training_step_output)
check_training_step_output(self.trainer.lightning_module, training_step_output)

training_step_output = self._process_training_step_output(training_step_output)
training_step_output, self._hiddens = process_training_step_output(self.trainer, training_step_output)
if training_step_output is None:
return

@@ -323,45 +296,6 @@
loss = closure_loss.detach().clone()
return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output)

def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Optional[ResultCollection]:
"""Adds the :param:`training_step_output` to the trainer's results

Args:
training_step_output: the output of the training step (before wrapping into an AttributeDict)

Returns:
the updated results if the training_step's output was not None else None
"""
if training_step_output is None:
return None

results = self.trainer._results

loss = None
hiddens = None

# handle dict return
if isinstance(training_step_output, dict):
# this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
loss = training_step_output.get("loss")
hiddens = training_step_output.get("hiddens")
# detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
hiddens = apply_to_collection(hiddens, Tensor, lambda t: t.detach())
# use the setter instead of `dict.update` because it calls `detach` on the tensor items
results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}

# handle scalar return
elif isinstance(training_step_output, Tensor):
loss = training_step_output

# map to results under the hood
results.minimize = loss
self._hiddens = hiddens

if self.trainer.move_metrics_to_cpu:
results.cpu()
return results

def _optimizer_step(
self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
) -> None:
@@ -531,7 +465,7 @@ def training_step_and_backward(

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
self._check_finite(result.loss)
check_finite(self.trainer.lightning_module, result.loss)

else:
self._warning_cache.warn(
@@ -540,17 +474,6 @@

return result

def _check_finite(self, loss: Tensor) -> None:
"""Checks fotr finite parameters and loss values.

Args:
loss: the loss value to check to be finite
"""
if not torch.isfinite(loss).all():
raise ValueError(f"The loss returned in `training_step` is {loss}.")
model = self.trainer.lightning_module
detect_nan_parameters(model)
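The `_check_finite`, `_check_training_step_output`, and `_process_training_step_output` methods removed in this file move to `pytorch_lightning.loops.utilities` as free functions, with the new call sites passing the `LightningModule` explicitly. A plausible sketch of the relocated `check_finite` helper, reconstructed from the method body above (the utilities file itself is not part of this diff, so treat this as an assumption):

```python
import torch
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters


def check_finite(model: "pl.LightningModule", loss: Tensor) -> None:
    """Raise if the loss is NaN/Inf, then check the model parameters for NaNs."""
    if not torch.isfinite(loss).all():
        raise ValueError(f"The loss returned in `training_step` is {loss}.")
    detect_nan_parameters(model)
```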

def backward(
self, result: STEP_OUTPUT, optimizer: Optional[torch.optim.Optimizer], *args: Any, **kwargs: Any
) -> None:
63 changes: 43 additions & 20 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -11,19 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import torch

from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.processors import IteratorBatchProcessor
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

# TODO: currently, the batch processor is only a loop when tbptt is enabled.
# As we introduce more specialized batch processors, we may want to choose a
# more suitable abstraction for them.
BATCH_LOOP_TYPE = Optional[Tuple[TrainingBatchLoop, IteratorBatchProcessor]]


class TrainingEpochLoop(loops.Loop):
"""
@@ -38,14 +44,15 @@ def __init__(self, min_steps: int, max_steps: int):
super().__init__()
self.min_steps: int = min_steps
self.max_steps: int = max_steps

self.global_step: int = 0
# the total batch index across all epochs
self.total_batch_idx: int = 0
self.is_last_batch: Optional[bool] = None
self.batch_progress = Progress()
self.scheduler_progress = SchedulerProgress()

self.batch_loop: Optional[TrainingBatchLoop] = None
self.batch_loop: BATCH_LOOP_TYPE = None
self.val_loop: Optional["loops.EvaluationLoop"] = None

self._results = ResultCollection(training=True)
@@ -70,7 +77,9 @@ def done(self) -> bool:
return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch)

def connect(
self, batch_loop: Optional[TrainingBatchLoop] = None, val_loop: Optional["loops.EvaluationLoop"] = None
self,
batch_loop: BATCH_LOOP_TYPE = None,
val_loop: Optional["loops.EvaluationLoop"] = None,
) -> None:
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
if batch_loop is not None:
@@ -107,21 +116,34 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
Raises:
StopIteration: When the epoch is canceled by the user returning -1
"""
_, (batch, is_last) = next(dataloader_iter)
self.is_last_batch = is_last

# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx)

self.batch_progress.increment_ready()
if isinstance(self.batch_loop, IteratorBatchProcessor):
# By contract, when taking `dataloader_iter` as an argument,
# `training_step` is responsible for reporting `is_last` in the
# result dict, which is used to determine the stop condition for
# the epoch. So as long as `advance` is invoked, it's correct to
# assume that there are more batches to be processed.
self.batch_progress.increment_ready()
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(dataloader_iter)
self.batch_progress.increment_processed()
is_last = batch_output.is_last
else:
_, (batch, is_last) = next(dataloader_iter)

# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx)

self.batch_progress.increment_ready()

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, self.batch_idx, self._dataloader_idx)

self.batch_progress.increment_processed()

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, self.batch_idx, self._dataloader_idx)

self.batch_progress.increment_processed()
self.is_last_batch = is_last

# when returning -1 from train_step, we end epoch early
if batch_output.signal == -1:
@@ -137,9 +159,10 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)

# hook
self.trainer.call_hook(
"on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, self._dataloader_idx
)
if not isinstance(self.batch_loop, IteratorBatchProcessor):
self.trainer.call_hook(
"on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, self._dataloader_idx
)
self.trainer.call_hook("on_batch_end")
self.trainer.logger_connector.on_batch_end()

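To make the `is_last` contract from `advance` concrete, here is a small, dependency-free sketch of the stop condition as seen from the loop's side. The names are illustrative, not the PR's API; the real loop additionally wraps the call in profiler contexts, progress tracking, and hook calls, and reads `is_last` off the processor's output object rather than a plain dict.

```python
from typing import Any, Dict, Iterator

import torch


def iterator_style_step(dataloader_iter: Iterator) -> Dict[str, Any]:
    # same unpacking as the epoch loop's prefetching iterator; the step itself
    # must surface `is_last` because the loop no longer fetches the batch
    _, (batch, is_last) = next(dataloader_iter)
    loss = torch.as_tensor(batch).mean()  # stand-in for real training logic
    return {"loss": loss, "is_last": is_last}


# toy driver mimicking the epoch loop: keep advancing until the step reports `is_last`
fake_iter = enumerate([([1.0, 2.0], False), ([3.0, 4.0], True)])
while True:
    out = iterator_style_step(fake_iter)
    if out["is_last"]:
        break
```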
15 changes: 15 additions & 0 deletions pytorch_lightning/loops/processors/__init__.py
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.processors.iterator_batch_processor import IteratorBatchProcessor # noqa: F401