Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Stateful PrecisionPlugin #11638

Merged
merged 20 commits into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pytorch_lightning/plugins/precision/apex_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ def optimizer_step(
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
optimizer.step(**kwargs)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
if "amp_scaling_state" in checkpoint:
amp.load_state_dict(checkpoint["amp_scaling_state"])
def state_dict(self) -> Dict[str, Any]:
return amp.state_dict()

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint["amp_scaling_state"] = amp.state_dict()
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if state_dict:
jjenniferdai marked this conversation as resolved.
Show resolved Hide resolved
amp.load_state_dict(state_dict)
12 changes: 6 additions & 6 deletions pytorch_lightning/plugins/precision/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def forward_context(self) -> Generator[None, None, None]:
with self.autocast_context_manager():
yield

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.scaler is not None and "native_amp_scaling_state" in checkpoint:
self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
def state_dict(self) -> Dict[str, Any]:
if self.scaler is not None:
checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
return self.scaler.state_dict()
jjenniferdai marked this conversation as resolved.
Show resolved Hide resolved

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if self.scaler is not None and state_dict:
self.scaler.load_state_dict(state_dict)
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import contextlib
from functools import partial
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -242,3 +242,9 @@ def teardown(self) -> None:

It is the right place to release memory and free other resources.
"""

def state_dict(self) -> Dict[str, Any]:
jjenniferdai marked this conversation as resolved.
Show resolved Hide resolved
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
pass
18 changes: 16 additions & 2 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,17 @@ def restore_training_state(self) -> None:
return

# restore precision plugin (scaler etc.)
self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
prec_plugin = self.trainer.precision_plugin
prec_plugin.on_load_checkpoint(self._loaded_checkpoint)
if prec_plugin.__class__.__name__ in self._loaded_checkpoint:
prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__name__])

# old checkpoints compatibility
# should we raise error and force user to run utilities/upgrade_checkpoint instead?
jjenniferdai marked this conversation as resolved.
Show resolved Hide resolved
if "amp_scaling_state" in self._loaded_checkpoint:
prec_plugin.load_state_dict(self._loaded_checkpoint["amp_scaling_state"])
if "native_amp_scaling_state" in self._loaded_checkpoint:
prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"])
jjenniferdai marked this conversation as resolved.
Show resolved Hide resolved

# restore loops and their progress
self.restore_loops()
Expand Down Expand Up @@ -372,7 +382,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
lr_schedulers.append(config.scheduler.state_dict())
checkpoint["lr_schedulers"] = lr_schedulers

self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
# precision plugin
prec_plugin = self.trainer.precision_plugin
checkpoint[prec_plugin.__class__.__name__] = self.trainer.precision_plugin.state_dict()
jjenniferdai marked this conversation as resolved.
Show resolved Hide resolved

# dump hyper-parameters
if model.hparams:
Expand All @@ -389,6 +401,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
model.on_save_checkpoint(checkpoint)
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_save_checkpoint(checkpoint)
if not weights_only:
self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
jjenniferdai marked this conversation as resolved.
Show resolved Hide resolved

# TODO: remove this in v1.8.
environment = self.trainer._accelerator_connector.cluster_environment
Expand Down