Add a flavor of training_step that takes dataloader_iter as an argument
Yifu Wang committed Aug 9, 2021
1 parent 4928dc5 commit 81a8010
Showing 9 changed files with 345 additions and 109 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
import time

_this_year = time.strftime("%Y")
__version__ = "1.5.0dev"
__version__ = "20210806"
__author__ = "William Falcon et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.

from pytorch_lightning.loops.base import Loop # noqa: F401
from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.batch import FlexibleOptimizationFlow, TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
1 change: 1 addition & 0 deletions pytorch_lightning/loops/batch/__init__.py
@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.batch.flexible_optimization_flow import FlexibleOptimizationFlow # noqa: F401
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401
161 changes: 161 additions & 0 deletions pytorch_lightning/loops/batch/flexible_optimization_flow.py
@@ -0,0 +1,161 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import OrderedDict
from copy import copy
from typing import Iterator, List, Optional, Tuple

import torch

import pytorch_lightning as pl
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.batch_loop_common import (
check_finite,
check_training_step_output,
process_training_step_output,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden

log = logging.getLogger(__name__)


class FlexibleOptimizationFlow:
"""
The flow for performing a training iteration when `training_step` needs access to the
dataloader. It is selected when the signature of `training_step` contains `dataloader_iter`:
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
The `training_step` is allowed to fetch multiple batches during one training iteration. The
framework provides a minimal amount of automation with regard to model optimization (hence the
"flexible" in the name). This flexibility allows for ease of experimentation with inter-batch
parallelism techniques.
This flow doesn't support `automatic_optimization` and `tbptt`. An error will be thrown if the
`LightningModule` or the `Trainer` is configured to use these features.
The `training_step` is responsible for reporting whether it has reached the last batch by
including an `is_last` field in the result dict. Failing to do so will result in an error.
The `training_step` should optimize the model with only one batch per call, for the sake of API
and reporting consistency (TODO: consider removing this limitation).
Args:
trainer_ref: a reference to the trainer
model_ref: a reference to the lightning module (for config validation purposes only)
"""

def __init__(self, trainer_ref: "pl.Trainer", model_ref: "pl.LightningModule") -> None:
if is_overridden("on_train_batch_start", model_ref):
raise MisconfigurationException(
"The model hook `on_train_batch_start` is not compatible with FlexibleOptimizationFlow."
)
if is_overridden("on_train_batch_end", model_ref):
raise MisconfigurationException(
"The model hook `on_train_batch_end` is not compatible with FlexibleOptimizationFlow."
)
if is_overridden("tbptt_split_batch", model_ref):
raise MisconfigurationException(
"The model hook `tbptt_split_batch` is not compatible with FlexibleOptimizationFlow."
)
if model_ref.automatic_optimization:
raise MisconfigurationException("`automatic_optimization` is not support by FlexibleOptimizationFlow.")
if trainer_ref.accumulate_grad_batches != 1:
raise MisconfigurationException(
"`accumulate_grad_batches` can only be 1 when using FlexibleOptimizationFlow."
)

self.trainer_ref = trainer_ref

# The following fields are not used by the flow since it doesn't support automatic
# optimization and tbptt. They are initialized regardless since they are currently expected by
# `FitLoop` or `TrainingEpochLoop`.
# TODO: come up with an abstraction for "batch processors" so they can be better decoupled
# from parent loops.
self.accumulated_loss: Optional[torch.Tensor] = None
self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=1)
self.optim_progress = OptimizationProgress()
self.split_idx: int = 0
self._skip_backward = False

def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
"""
Returns the number of active optimizers.
"""
return len(self.trainer_ref.optimizers)

def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, torch.optim.Optimizer]]:
"""
Returns the currently active optimizers.
Returns:
A list of tuples (opt_idx, optimizer) of currently active optimizers.
"""
return list(enumerate(self.trainer_ref.optimizers))

def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]:
"""
Args:
dataloader_iter: the iterator over the dataloader producing the new batch
"""
dataloader_iter = itertools.starmap(
lambda batch_idx, batch_with_is_last: batch_with_is_last[0], dataloader_iter
)

self.trainer_ref.logger_connector.on_batch_start()
response = self.trainer_ref.call_hook("on_batch_start")
if response == -1:
return AttributeDict(signal=-1)

self.trainer_ref.fit_loop.epoch_loop.batch_progress.increment_started()

# give the PL module a result for logging
model_ref = self.trainer_ref.lightning_module

with self.trainer_ref.profiler.profile("model_forward"):
# manually capture logged metrics
model_ref._current_fx_name = "training_step"
with self.trainer_ref.profiler.profile("training_step"):
step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter)])
training_step_output = self.trainer_ref.accelerator.training_step(step_kwargs)
self.trainer_ref.accelerator.post_training_step()

training_step_output = self.trainer_ref.call_hook("training_step_end", training_step_output)
check_training_step_output(self.trainer_ref, training_step_output)

if training_step_output is None or "is_last" not in training_step_output:
raise MisconfigurationException(
"When using `FlexibleOptimizationFlow`, `training_step` must return a dict containing `is_last` "
"which indicated whether there are more batches to be processed."
)
is_last = training_step_output["is_last"]
training_step_output, _ = process_training_step_output(self.trainer_ref, training_step_output)

if self.trainer_ref.terminate_on_nan:
check_finite(self.trainer_ref, training_step_output.minimize)

batch_outputs = [[] for _ in range(len(self.trainer_ref.optimizers))]

batch_outputs[0].append(copy(training_step_output))
return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last)

def teardown(self) -> None:
"""
No-op. Only defined to comply with FitLoop's expectation.
"""
pass
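
For context, here is a minimal sketch of a LightningModule that would be routed through this flow: it uses manual optimization, pulls its own batch from `dataloader_iter`, and returns the `is_last` flag the flow requires. The class name, layer sizes, batch layout, and the batch-counting scheme are illustrative assumptions, not part of this commit.

from typing import Iterator

import torch
from pytorch_lightning import LightningModule


class ManualFlowModel(LightningModule):
    def __init__(self):
        super().__init__()
        # required: FlexibleOptimizationFlow rejects automatic optimization
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)
        self._seen_batches = 0

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def training_step(self, dataloader_iter: Iterator):
        # fetch the batch ourselves; prefetching further batches here is what enables
        # experimentation with inter-batch parallelism
        x, y = next(dataloader_iter)
        opt = self.optimizers()
        opt.zero_grad()
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        self.manual_backward(loss)
        opt.step()
        # report whether this was the last batch of the epoch; the flow raises an error
        # if the returned dict has no `is_last` key (counter reset between epochs omitted)
        self._seen_batches += 1
        is_last = self._seen_batches == self.trainer.num_training_batches
        return {"loss": loss, "is_last": is_last}
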
93 changes: 10 additions & 83 deletions pytorch_lightning/loops/batch/training_batch_loop.py
Expand Up @@ -16,7 +16,7 @@
from contextlib import contextmanager
from copy import copy
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import numpy as np
import torch
@@ -27,13 +27,15 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.batch_loop_common import (
check_finite,
check_training_step_output,
process_training_step_output,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -253,32 +255,7 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
self._check_finite(opt_closure_result.loss)

def _check_training_step_output(self, training_step_output: STEP_OUTPUT) -> None:
"""Sanity checks that training produced a valid output and optimizer step has already been called in manual
optimization.
Args:
training_step_output: the output of the training step (before wrapping in an AttributeDict)
"""
if isinstance(training_step_output, Tensor) and not self.trainer.lightning_module.automatic_optimization:
if training_step_output.grad_fn is None:
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
elif self.trainer.lightning_module.automatic_optimization:
if not any(
(
isinstance(training_step_output, Tensor),
(isinstance(training_step_output, Mapping) and "loss" in training_step_output),
training_step_output is None,
)
):
raise MisconfigurationException(
"In automatic optimization, `training_step` must either return a Tensor, "
"a dict with key 'loss' or None (where the step will be skipped)."
)
check_finite(self.trainer, opt_closure_result.loss)

def _training_step(
self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor
@@ -308,9 +285,9 @@ def _training_step(

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

self._check_training_step_output(training_step_output)
check_training_step_output(self.trainer, training_step_output)

training_step_output = self._process_training_step_output(training_step_output)
training_step_output, self._hiddens = process_training_step_output(self.trainer, training_step_output)
if training_step_output is None:
return

@@ -323,45 +300,6 @@
loss = closure_loss.detach().clone()
return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output)

def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Optional[ResultCollection]:
"""Adds the :param:`training_step_output` to the trainer's results
Args:
training_step_output: the output of the training step (before wrapping into an AttributeDict)
Returns:
the updated results if the training_step's output was not None else None
"""
if training_step_output is None:
return None

results = self.trainer._results

loss = None
hiddens = None

# handle dict return
if isinstance(training_step_output, dict):
# this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
loss = training_step_output.get("loss")
hiddens = training_step_output.get("hiddens")
# detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
hiddens = apply_to_collection(hiddens, Tensor, lambda t: t.detach())
# use the setter instead of `dict.update` because it calls `detach` on the tensor items
results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}

# handle scalar return
elif isinstance(training_step_output, Tensor):
loss = training_step_output

# map to results under the hood
results.minimize = loss
self._hiddens = hiddens

if self.trainer.move_metrics_to_cpu:
results.cpu()
return results

def _optimizer_step(
self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
) -> None:
@@ -529,7 +467,7 @@ def training_step_and_backward(

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
self._check_finite(result.loss)
check_finite(self.trainer, result.loss)

else:
self._warning_cache.warn(
@@ -538,17 +476,6 @@

return result

def _check_finite(self, loss: Tensor) -> None:
"""Checks fotr finite parameters and loss values.
Args:
loss: the loss value to check to be finite
"""
if not torch.isfinite(loss).all():
raise ValueError(f"The loss returned in `training_step` is {loss}.")
model = self.trainer.lightning_module
detect_nan_parameters(model)

def backward(
self, result: STEP_OUTPUT, optimizer: Optional[torch.optim.Optimizer], *args: Any, **kwargs: Any
) -> None:
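
The shared helpers imported above live in the new `pytorch_lightning.utilities.batch_loop_common` module, which is among the changed files not shown on this page. Judging from the call sites (e.g. `check_finite(self.trainer, result.loss)`) and the methods deleted here, it presumably exposes module-level versions of these helpers that take the trainer as an explicit argument. A sketch of `check_finite` under that assumption, adapted from the removed `_check_finite` method:

import torch
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters


def check_finite(trainer: "pl.Trainer", loss: Tensor) -> None:
    # raise if the reported loss is not finite, then scan model parameters for NaN/inf values
    if not torch.isfinite(loss).all():
        raise ValueError(f"The loss returned in `training_step` is {loss}.")
    detect_nan_parameters(trainer.lightning_module)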