Add progress tracking on Loops - 2/n (#8362)

* resolve issues

* update

* update

* update

* add more exceptions

* resolve bug

* update

* update

* update changelog

* resolve bug

* resolve comments

* update

* update

* update changelog

* update

* update

* remove space

* update

* add progress tracking to loops

* validate json

* update

* convert to dict for better readability

* validate reload

* update

* update

* update on comments

* remove dead code

* clean changelog

* clean changelog

* update

* update on comments

* CHANGELOG

* CHANGELOG

* Update pytorch_lightning/loops/base.py

Co-authored-by: Carlos Mocholí <[email protected]>

* whitespace suggestions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make fault_tolerant_enabled protected

* whitespace fixes around Args

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* typo it's -> its

* fix copy-paste typo in progress docstring

* Delete classes

* Minor change

* docs

* protected get_loops_state

* merge restore_loops with restore_progress

* Fix tests after removals

* explicit save with trainer.save_checkpoint()

* handle optimization restart based on optimizer_idx

* update increments

* update val batch progress and remove iteration count

* update progress tracking for dataloader loops

* remove self.dataloader_idx from eval_epoch_loop

* add batch progress to predict loop

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* incorporate progress tracking for current_epoch

* Fix test

* Actually remove it

* Remove unused TrainingEpochProgress

* Fix optimization progress - missing scheduler

* Restarting changes

* Scheduler progress

* Unused property, reset on epoch

* Resolve FIXME

* Remove FIXME

* fix test_progress (wip)

* fix batch_progress.current.reset

* Hold off on split progress. Out of scope of this PR

* Unnecessary if

* fix structure in test_progress

* structure

* clean up unused variables in test_progress

* refactor naming and organization in test_progress

* Unnecessary variable

* Remove unnecessary diff

* Improve comment

* Undo typing change to avoid polluting everything with mypy fixes

* Fix and improve test_loops.py

* Fix and organize `test_loop_state_dict`

* Remove unnecessary checks in test

* Update test after disallowing updates on None attributes

* Typing

* Minor test cleanup

* Fix and move loop test

* Move test from progress to loops

* Reset the scheduler progress

* SchedulerProgress fix

* Consistent whitespace

* Fix final test

* Minor test changes

* One test to rule them all

* Formatting

* Rename and clean variables

* Shorter names

* Shorter scheduler name

* Fix optimizer step calculation for stop_batch=2

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove empty connects

* Update CHANGELOG

* Finally got the formula right

* Fix final thing!!!

* Do not check state dicts

* parametrize multiple_dataloader progress test

* Update CHANGELOG.md

Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <[email protected]>
5 people authored Jul 19, 2021
1 parent cbf71d0 commit 7bb810f
Showing 19 changed files with 602 additions and 664 deletions.
9 changes: 5 additions & 4 deletions CHANGELOG.md
@@ -33,9 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Progress tracking
* Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140))
* Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140))
* Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244))
* Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))


- Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))
@@ -92,6 +92,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fault-tolerant training
* Added `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
* Added `{,load_}state_dict` to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197))
* Set `Loop.restarting=False` at the end of the first iteration ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Save the loops state with the checkpoint (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Save a checkpoint to restore the state on exception (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))


- Added `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
@@ -402,8 +405,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287))




### Fixed

- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
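
The opt-in entries in this CHANGELOG diff pair with the commit note "explicit save with trainer.save_checkpoint()". Below is a minimal sketch of what that opt-in save could look like from user code; the `PL_FAULT_TOLERANT_TRAINING` environment variable, the `TinyModel` module, and the checkpoint filename are assumptions for illustration, not part of this diff.

```python
# A minimal sketch, assuming the opt-in is the PL_FAULT_TOLERANT_TRAINING environment
# variable; TinyModel and the checkpoint filename are made up for illustration.
import os

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # assumed opt-in switch
    dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=4, logger=False)
    trainer.fit(TinyModel(), DataLoader(dataset, batch_size=8))
    # Explicit save, as in the commit message; with the opt-in enabled the checkpoint
    # should also carry the state of the loops.
    trainer.save_checkpoint("with_loops_state.ckpt")
```
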
12 changes: 7 additions & 5 deletions pytorch_lightning/loops/base.py
@@ -46,6 +46,7 @@ class Loop(ABC):
"""

def __init__(self) -> None:
# TODO: replace by progress tracking
self.iteration_count: int = 0
self.restarting = False
self._trainer: Optional['pl.Trainer'] = None
@@ -56,8 +57,8 @@ def trainer(self) -> Optional['pl.Trainer']:

@trainer.setter
def trainer(self, trainer: 'pl.Trainer'):
"""Connect the Trainer to this loop and all children."""
if not isinstance(trainer, pl.Trainer) and trainer is not None:
"""Connects this loop's trainer and its children"""
if not isinstance(trainer, pl.Trainer):
raise MisconfigurationException(
f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
)
@@ -112,6 +113,7 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
self.advance(*args, **kwargs)
self.on_advance_end()
self.iteration_count += 1
self.restarting = False
except StopIteration:
break

@@ -158,7 +160,7 @@ def on_save_checkpoint(self) -> Dict:
"""
return {}

def on_load_checkpoint(self, state_dict: Dict):
def on_load_checkpoint(self, state_dict: Dict) -> None:
"""Called when loading a model checkpoint, use to reload loop state."""

def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = "") -> Dict:
@@ -183,14 +185,14 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] =

return destination

def load_state_dict(self, state_dict: Dict, prefix="", restart_progress: bool = True):
def load_state_dict(self, state_dict: Dict, prefix: str = "", restart_progress: bool = True) -> None:
""" Loads the state of this loop and all its children. """
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress)
for k, v in self.__dict__.items():
if isinstance(v, Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)

def _load_from_state_dict(self, state_dict, prefix, restart_progress):
def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None:
for k, v in self.__dict__.items():
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[prefix + k])
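
The `state_dict`/`load_state_dict` methods above collect the loop's `BaseProgress` fields and recurse into child loops with a dotted prefix, setting `restarting` on load and clearing it at the end of the first `run` iteration. Below is a standalone sketch of that pattern; `ToyProgress` and `ToyLoop` are illustrative stand-ins, not the library's classes.

```python
# Standalone sketch of the save/restore pattern in `state_dict`/`load_state_dict` above.
# ToyProgress/ToyLoop are illustrative stand-ins for BaseProgress/Loop.
from dataclasses import asdict, dataclass
from typing import Dict, Optional


@dataclass
class ToyProgress:
    ready: int = 0
    completed: int = 0

    def state_dict(self) -> Dict:
        return asdict(self)

    def load_state_dict(self, state: Dict) -> None:
        self.ready = state["ready"]
        self.completed = state["completed"]


class ToyLoop:
    def __init__(self) -> None:
        self.progress = ToyProgress()
        self.restarting = False

    def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
        # collect own progress fields, then recurse into child loops with a dotted prefix
        destination = {} if destination is None else destination
        for name, value in self.__dict__.items():
            if isinstance(value, ToyProgress):
                destination[prefix + name] = value.state_dict()
            elif isinstance(value, ToyLoop):
                value.state_dict(destination, prefix + name + ".")
        return destination

    def load_state_dict(self, state_dict: Dict, prefix: str = "") -> None:
        for name, value in self.__dict__.items():
            if isinstance(value, ToyProgress):
                value.load_state_dict(state_dict[prefix + name])
            elif isinstance(value, ToyLoop):
                value.load_state_dict(state_dict, prefix + name + ".")
        self.restarting = True  # cleared again at the end of the first `run` iteration


outer = ToyLoop()
outer.child = ToyLoop()  # nested loop, mirrors e.g. fit_loop.epoch_loop
outer.progress.ready = 3
state = outer.state_dict()
print(state)  # {'progress': {'ready': 3, 'completed': 0}, 'child.progress': {'ready': 0, 'completed': 0}}

restored = ToyLoop()
restored.child = ToyLoop()
restored.load_state_dict(state)
print(restored.progress.ready, restored.restarting)  # 3 True
```
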
35 changes: 17 additions & 18 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -23,12 +23,11 @@
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, OptimizationProgress
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -50,7 +49,6 @@ def __init__(self) -> None:
self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
self.batch_idx: int = 0
self.split_idx: Optional[int] = None
self.progress = BatchProgress()
self.optim_progress = OptimizationProgress()

self._warning_cache: WarningCache = WarningCache()
@@ -59,21 +57,6 @@ def __init__(self) -> None:
self._remaining_splits: Optional[List[Any]] = None
self._skip_backward: bool = False

def connect(
self,
trainer: 'pl.Trainer',
*args: Any,
progress: Optional[BatchProgress] = None,
optim_progress: Optional[OptimizationProgress] = None,
**kwargs: Any
) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress
if optim_progress is not None:
self.optim_progress = optim_progress

@property
def done(self) -> bool:
"""Returns if all batch splits have been processed already"""
@@ -109,6 +92,8 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
if response == -1:
return AttributeDict(signal=-1)

self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()

super().run(batch, batch_idx, dataloader_idx)
output = AttributeDict(signal=0, training_step_output=self.batch_outputs)
self.batch_outputs = None # free memory
@@ -149,6 +134,13 @@ def advance(self, batch, batch_idx, dataloader_idx):

if self.trainer.lightning_module.automatic_optimization:
for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
# handle optimization restart
if self.restarting:
if opt_idx < self.optim_progress.optimizer_idx:
continue

self.optim_progress.optimizer_idx = opt_idx

result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
if result:
self.batch_outputs[opt_idx].append(result.training_step_output)
@@ -395,6 +387,8 @@ def _optimizer_step(
# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)

self.optim_progress.optimizer.step.increment_ready()

# model hook
model_ref.optimizer_step(
self.trainer.current_epoch,
@@ -407,13 +401,17 @@ def _optimizer_step(
using_lbfgs=is_lbfgs,
)

self.optim_progress.optimizer.step.increment_completed()

def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
"""Calls the ``on_before_zero_grad`` hook.
Args:
optimizer: the current optimizer
"""
self.optim_progress.optimizer.zero_grad.increment_ready()
self.trainer.call_hook('on_before_zero_grad', optimizer)
self.optim_progress.optimizer.zero_grad.increment_started()

def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
"""Zeroes out all gradients of parameters optimized by the current optimizer.
@@ -424,6 +422,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
opt_idx: the index of the current optimizer
"""
self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self.optim_progress.optimizer.zero_grad.increment_completed()

def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.
Expand Down
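The hunks above bracket the optimizer step and `zero_grad` calls with `increment_ready`/`increment_started`/`increment_completed` and skip optimizers whose index is below the stored `optimizer_idx` when restarting. Below is a toy sketch of the counter bracketing (ready/completed only) and the restart skip; `ToyTracker` and `ToyOptimProgress` are illustrative names, not the library's dataclasses.

```python
# Toy sketch of two ideas in this file's hunks: bracketing each optimizer step with
# ready/completed increments, and skipping optimizers below the stored optimizer_idx
# on restart. ToyTracker/ToyOptimProgress are illustrative names, not the library's.
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyTracker:
    ready: int = 0
    completed: int = 0

    def increment_ready(self) -> None:
        self.ready += 1

    def increment_completed(self) -> None:
        self.completed += 1


@dataclass
class ToyOptimProgress:
    optimizer_idx: int = 0
    step: ToyTracker = field(default_factory=ToyTracker)


def advance(optimizer_steps: List[Callable[[], None]], progress: ToyOptimProgress, restarting: bool) -> None:
    for opt_idx, do_step in enumerate(optimizer_steps):
        if restarting and opt_idx < progress.optimizer_idx:
            continue  # this optimizer already finished before the interruption
        progress.optimizer_idx = opt_idx
        progress.step.increment_ready()
        do_step()  # stand-in for the model's optimizer_step hook
        progress.step.increment_completed()


progress = ToyOptimProgress(optimizer_idx=1)
advance([lambda: None, lambda: None, lambda: None], progress, restarting=True)
print(progress.step.completed)  # 2 -> the first optimizer was skipped on restart
```
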
20 changes: 16 additions & 4 deletions pytorch_lightning/loops/dataloader/dataloader_loop.py
@@ -13,16 +13,21 @@
# limitations under the License.

from abc import abstractmethod
from typing import Sequence
from typing import Any, Sequence

from torch.utils.data import DataLoader

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.trainer.progress import DataLoaderProgress


class DataLoaderLoop(Loop):
"""Base class to loop over all dataloaders"""

def __init__(self):
super().__init__()
self.dataloader_progress = DataLoaderProgress()

@property
@abstractmethod
def dataloaders(self) -> Sequence[DataLoader]:
@@ -31,7 +36,7 @@ def dataloaders(self) -> Sequence[DataLoader]:
@property
def current_dataloader_idx(self) -> int:
"""Returns the index of the current dataloader"""
return self.iteration_count
return self.dataloader_progress.current.ready - 1

@property
def current_dataloader(self) -> DataLoader:
@@ -46,8 +51,15 @@ def num_dataloaders(self) -> int:
@property
def done(self) -> bool:
"""Returns whether all dataloaders have been processed"""
return self.current_dataloader_idx >= self.num_dataloaders
return self.dataloader_progress.current.completed >= self.num_dataloaders

def reset(self) -> None:
"""Resets the internal state"""
self.iteration_count = 0
if not self.restarting:
self.dataloader_progress.current.reset()

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
self.dataloader_progress.increment_ready()

def on_advance_end(self) -> None:
self.dataloader_progress.increment_completed()
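
With `DataLoaderProgress` owned by the base loop, the current dataloader index is derived from the `ready` counter and completion from `completed`, instead of `iteration_count`. Below is a minimal sketch of that counter-driven loop, with illustrative names only.

```python
# Minimal sketch of the counter-driven dataloader loop above: bump "ready" before
# advancing, "completed" after, derive the current index as ready - 1, and stop once
# completed reaches the number of dataloaders. Illustrative names only.
from dataclasses import dataclass
from typing import Iterable, Sequence


@dataclass
class ToyDataLoaderProgress:
    ready: int = 0
    completed: int = 0


def run_dataloaders(dataloaders: Sequence[Iterable]) -> ToyDataLoaderProgress:
    progress = ToyDataLoaderProgress()
    while progress.completed < len(dataloaders):  # mirrors the new `done` property
        progress.ready += 1                        # on_advance_start
        current_idx = progress.ready - 1           # mirrors `current_dataloader_idx`
        for _batch in dataloaders[current_idx]:
            pass                                   # evaluate/predict on the batch
        progress.completed += 1                    # on_advance_end
    return progress


print(run_dataloaders([[1, 2], [3]]))  # ToyDataLoaderProgress(ready=2, completed=2)
```
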
16 changes: 5 additions & 11 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -21,7 +21,6 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import EpochLoopProgress
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
Expand All @@ -33,8 +32,6 @@ class EvaluationLoop(DataLoaderLoop):
def __init__(self):
super().__init__()
self.outputs = []
self.progress = EpochLoopProgress()

self.epoch_loop = EvaluationEpochLoop()

self._results = ResultCollection(training=False)
@@ -66,19 +63,15 @@ def predictions(self):
"""Returns the predictions from all dataloaders"""
return self.epoch_loop.predictions

def connect(
self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any
) -> None:
def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress
self.epoch_loop.connect(trainer, progress=self.progress.epoch)
self.epoch_loop.connect(trainer)

@property
def done(self) -> bool:
"""Returns whether all dataloaders are processed or evaluation should be skipped altogether"""
return (self.current_dataloader_idx >= len(self.dataloaders)) or self.skip
return super().done or self.skip

@property
def skip(self) -> bool:
@@ -88,14 +81,15 @@ def skip(self) -> bool:

def reset(self) -> None:
"""Resets the internal state of the loop"""
self.iteration_count = 0
self._max_batches = self.get_max_batches()
# bookkeeping
self.outputs = []

if isinstance(self._max_batches, int):
self._max_batches = [self._max_batches] * len(self.dataloaders)

super().reset()

def on_skip(self) -> List:
return []

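
After this change the evaluation loop no longer forwards a progress object through `connect`, and `done` simply composes the base class's counter check with the `skip` shortcut. Below is a toy sketch of that composition; the classes are illustrative, not the library's.

```python
# Toy sketch of the `done`/`skip` composition above: the dataloader base class decides
# completion from its progress counter, and evaluation only adds the "nothing to run"
# shortcut. Classes are illustrative, not the library's.
from typing import List


class ToyDataLoaderLoop:
    def __init__(self, num_dataloaders: int) -> None:
        self.num_dataloaders = num_dataloaders
        self.completed = 0  # stand-in for dataloader_progress.current.completed

    @property
    def done(self) -> bool:
        return self.completed >= self.num_dataloaders


class ToyEvaluationLoop(ToyDataLoaderLoop):
    def __init__(self, max_batches: List[int]) -> None:
        super().__init__(num_dataloaders=len(max_batches))
        self.max_batches = max_batches

    @property
    def skip(self) -> bool:
        return sum(self.max_batches) == 0

    @property
    def done(self) -> bool:
        return super().done or self.skip


print(ToyEvaluationLoop(max_batches=[0, 0]).done)  # True -> evaluation is skipped entirely
```
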
16 changes: 2 additions & 14 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -7,7 +7,6 @@
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.trainer.progress import EpochLoopProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT

Expand All @@ -19,8 +18,6 @@ def __init__(self):
super().__init__()
self.predictions: Optional[List[List[Any]]] = None
self.epoch_batch_indices: Optional[List[List[int]]] = None
self.progress = EpochLoopProgress()

self.epoch_loop = PredictionEpochLoop()

self._results = None # for `trainer._results` access
@@ -67,23 +64,14 @@ def dataloaders(self) -> Sequence[DataLoader]:
"""Returns all prediction dataloaders"""
return self.trainer.predict_dataloaders

@property
def done(self) -> bool:
"""Whether prediction is finished: Max batches run or all dataloaders processed"""
return self.current_dataloader_idx >= len(self.dataloaders)

@property
def skip(self) -> bool:
return sum(self.max_batches) == 0

def connect(
self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any
) -> None:
def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress
self.epoch_loop.connect(trainer, progress=self.progress.epoch)
self.epoch_loop.connect(trainer)

def reset(self) -> None:
"""Resets the internal state of the loop for a new run"""
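
A hedged follow-up to the checkpoint sketch after the CHANGELOG diff: once the opt-in save has run, the loop progress should travel with the checkpoint and be picked up again on resume. The "loops" checkpoint key and the `resume_from_checkpoint` argument used below are assumptions about this release line, not something this diff states.

```python
# Inspect the loop state stored by the earlier trainer.save_checkpoint(...) sketch and
# resume from it. The "loops" key and `resume_from_checkpoint` are assumptions.
import torch

import pytorch_lightning as pl

ckpt = torch.load("with_loops_state.ckpt")
print(ckpt.get("loops"))  # nested progress counters saved by the loops, if present

# Resuming should restore those counters instead of starting from zero.
trainer = pl.Trainer(max_epochs=2, resume_from_checkpoint="with_loops_state.ckpt")
# trainer.fit(TinyModel(), ...)  # hypothetical continuation of the earlier sketch
```
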
(diffs for the remaining changed files are not shown)
