[Refactor] Improve loops API 1/n #8334

Merged · 30 commits · Jul 12, 2021
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -184,7 +184,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Refactored prediction loop interface; added new classes `PredictionLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
* Removed `pytorch_lightning/trainer/predict_loop.py` ([#8094](https://github.com/PyTorchLightning/pytorch-lightning/pull/8094))
* Moved result teardown to the loops ([#8245](https://github.com/PyTorchLightning/pytorch-lightning/pull/8245))

* Improve `Loop` API to better handle children `state_dict` and `progress` ([#8334](https://github.com/PyTorchLightning/pytorch-lightning/pull/8334))

- Refactored logging
* Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736))
@@ -339,6 +339,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287))




### Fixed

- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
89 changes: 72 additions & 17 deletions pytorch_lightning/loops/base.py
@@ -18,6 +18,8 @@
from deprecate import void

import pytorch_lightning as pl
from pytorch_lightning.trainer.progress import BaseProgress, Tracker
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -45,8 +47,24 @@ class Loop(ABC):

def __init__(self) -> None:
self.iteration_count: int = 0
self.trainer: Optional['pl.Trainer'] = None
self.restarting = False
self._trainer: Optional['pl.Trainer'] = None

@property
def trainer(self) -> Optional['pl.Trainer']:
return self._trainer

@trainer.setter
def trainer(self, trainer: 'pl.Trainer'):
"""Connect the Trainer to this loop and all children."""
if not isinstance(trainer, pl.Trainer) and trainer is not None:
raise MisconfigurationException(
f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
)
self._trainer = trainer
for v in self.__dict__.values():
if isinstance(v, Loop):
v.trainer = trainer

@property
@abstractmethod
@@ -61,10 +79,6 @@ def skip(self) -> bool:
def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
"""Connects Loop with all the necessary things like connectors and accelerators."""
# TODO(@justusschock): Make the trainer a weakref/proxy
if not isinstance(trainer, pl.Trainer):
raise MisconfigurationException(
f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
)
self.trainer = trainer

def on_skip(self) -> Optional[Any]:
@@ -88,11 +102,7 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
if self.skip:
return self.on_skip()

if self.restarting:
self.restore()
self.restarting = False
else:
self.reset()
self.reset()

Review discussion on this change:

Contributor: So why do we not want the `restarting` property anymore? And why did you leave it in the `__init__`?

Contributor Author: This is still used in `reset` (`self.restarting`). Check #8364. Adrian convinced me it was cleaner to make it manual: `self.restarting` is automatically set to `True` by the Trainer when calling the loop's `load_state_dict` function.

Contributor: Where exactly is that PR setting `restarting` on the loops?
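As a concrete illustration of the flow discussed in this thread, a minimal sketch (the class and its counter are hypothetical, not part of this PR) of a loop whose `reset` consults `self.restarting`, which `load_state_dict` now sets automatically; clearing the flag inside `reset` is an assumption of the sketch:

```python
from pytorch_lightning.loops.base import Loop


class CountingLoop(Loop):
    """Hypothetical loop that counts batches and can resume mid-run."""

    def __init__(self) -> None:
        super().__init__()
        self.batch_idx = 0  # illustrative counter, not a real Lightning attribute

    @property
    def done(self) -> bool:
        return self.batch_idx >= 10

    def reset(self) -> None:
        # On a fresh run, start from zero; when restarting from a checkpoint,
        # keep the restored counter. Clearing the flag here is a sketch choice.
        if not self.restarting:
            self.batch_idx = 0
        self.restarting = False

    def advance(self) -> None:
        self.batch_idx += 1


loop = CountingLoop()
loop.load_state_dict(loop.state_dict())  # marks the loop as restarting
loop.reset()                             # keeps the (restored) counter
```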

self.on_run_start(*args, **kwargs)

@@ -108,9 +118,6 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
output = self.on_run_end()
return output

def restore(self) -> None:
"""Restore the internal state of the loop the beginning of run if restarting is ``True``."""

@abstractmethod
def reset(self) -> None:
"""Resets the internal state of the loop at the beginning of each call to :attr:`run`."""
@@ -142,9 +149,57 @@ def on_run_end(self) -> Any:
def teardown(self) -> None:
"""Use to release memory etc."""

def load_state_dict(self, state_dict: Dict) -> None:
"""Restore the loop state from the provided state_dict."""

def state_dict(self) -> Dict:
"""Return the loop current states."""

def on_save_checkpoint(self) -> Dict:
"""
Called when saving a model checkpoint, use to persist loop state.

Returns:
The current loop state.
"""
return {}

def on_load_checkpoint(self, state_dict: Dict):
"""Called when loading a model checkpoint, use to reload loop state."""

def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = "") -> Dict:
"""
The state dict is determined by the state and progress of this loop and all its children.

Args:
destination: An existing dictionary to update with this loop's state. By default a new dictionary
is returned.
prefix: A prefix for each key in the state dictionary
"""
if destination is None:
destination = {}

destination[prefix + "state_dict"] = self.on_save_checkpoint()

for k, v in self.__dict__.items():
if isinstance(v, BaseProgress):
destination[prefix + k] = v.state_dict()
elif isinstance(v, Loop):
v.state_dict(destination, prefix + k + '.')

return destination

def load_state_dict(self, state_dict: Dict, prefix="", restart_progress: bool = True):
""" Loads the state of this loop and all its children. """
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress)
for k, v in self.__dict__.items():
if isinstance(v, Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)

def _load_from_state_dict(self, state_dict, prefix, restart_progress):
for k, v in self.__dict__.items():
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[prefix + k])
if restart_progress:

def restart(tracker: Tracker):
tracker.reset_on_restart()

apply_to_collection(v, Tracker, restart)

self.on_load_checkpoint(state_dict[prefix + "state_dict"])
self.restarting = True
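
A usage sketch (the `OuterLoop`/`InnerLoop` classes and attribute names are hypothetical, not part of this PR) of how the recursive `state_dict` and `load_state_dict` above compose across a parent loop, a child loop, and a progress dataclass:

```python
from typing import Dict

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.trainer.progress import Progress


class InnerLoop(Loop):
    def __init__(self) -> None:
        super().__init__()
        self.progress = Progress()  # collected automatically by `state_dict`

    @property
    def done(self) -> bool:
        return True

    def reset(self) -> None:
        pass

    def advance(self) -> None:
        pass


class OuterLoop(InnerLoop):
    def __init__(self) -> None:
        super().__init__()
        self.inner = InnerLoop()  # child loops are traversed recursively

    def on_save_checkpoint(self) -> Dict:
        return {"iteration_count": self.iteration_count}


loop = OuterLoop()
state = loop.state_dict()
# Key layout produced by the traversal above:
#   "state_dict"        -> OuterLoop.on_save_checkpoint()
#   "progress"          -> OuterLoop's own Progress (inherited attribute)
#   "inner.state_dict"  -> InnerLoop.on_save_checkpoint() (the default {})
#   "inner.progress"    -> InnerLoop's Progress as a plain dict
loop.load_state_dict(state)  # restores children and sets `loop.restarting = True`
```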
27 changes: 20 additions & 7 deletions pytorch_lightning/trainer/progress.py
@@ -16,7 +16,7 @@


@dataclass
class _DataclassStateDictMixin:
class BaseProgress:

def state_dict(self) -> dict:
return asdict(self)
@@ -25,14 +25,14 @@ def load_state_dict(self, state_dict: dict) -> None:
self.__dict__.update(state_dict)

@classmethod
def from_state_dict(cls, state_dict: dict) -> "_DataclassStateDictMixin":
def from_state_dict(cls, state_dict: dict) -> "BaseProgress":
obj = cls()
obj.load_state_dict(state_dict)
return obj


@dataclass
class Tracker(_DataclassStateDictMixin):
class Tracker(BaseProgress):
"""
Track an event's progress.

@@ -70,9 +70,22 @@ def __repr__(self):
args = [f"{k}={v}" for k, v in self.__dict__.items() if v is not None]
return f"{self.__class__.__name__}({', '.join(args)})"

def reset_on_restart(self):
"""Reset the progress on restart"""
value = self.completed if self.processed is None else self.processed

if self.ready is not None:
self.ready = value
if self.started is not None:
self.started = value
if self.processed is not None:
self.processed = value
if self.completed is not None:
self.completed = value
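
As a worked illustration (the values are chosen arbitrarily, not taken from the PR), restarting rolls every non-`None` counter back to the number of processed items:

```python
from pytorch_lightning.trainer.progress import Tracker

# A run interrupted after 5 items were processed but 6 were already started.
tracker = Tracker(ready=6, started=6, processed=5, completed=5)
tracker.reset_on_restart()
# `processed` is not None, so it is taken as the reference value:
assert tracker == Tracker(ready=5, started=5, processed=5, completed=5)
```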


@dataclass
class Progress(_DataclassStateDictMixin):
class Progress(BaseProgress):
"""
Track aggregated and current progress.

@@ -150,7 +163,7 @@ def load_state_dict(self, state_dict: dict) -> None:


@dataclass
class OptimizerProgress(_DataclassStateDictMixin):
class OptimizerProgress(BaseProgress):
"""
Track optimizer progress.

@@ -172,7 +185,7 @@ def load_state_dict(self, state_dict: dict) -> None:


@dataclass
class OptimizationProgress(_DataclassStateDictMixin):
class OptimizationProgress(BaseProgress):
"""
Track optimization progress.

@@ -203,7 +216,7 @@ def load_state_dict(self, state_dict: dict) -> None:


@dataclass
class EpochLoopProgress(_DataclassStateDictMixin):
class EpochLoopProgress(BaseProgress):
"""
Tracks epoch loop progress.
These counters are local to a trainer rank. By default, they are not globally synced across all ranks.