
Remove should_rank_save_checkpoint property from TTP #11070

Merged (10 commits) on Dec 21, 2021
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -243,7 +243,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the `mp_queue` attribute from `DDPSpawnPlugin` and `TPUSpawnPlugin` ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))


- Removed unnessesary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))
- Removed unnecessary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))


- Removed `should_rank_save_checkpoint` property from `TrainingTypePlugin` ([#11070](https://github.com/PyTorchLightning/pytorch-lightning/pull/11070))


- Removed `model_sharded_context` method from `Accelerator` ([#10886](https://github.com/PyTorchLightning/pytorch-lightning/pull/10886))
12 changes: 1 addition & 11 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Optional
from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -22,7 +22,6 @@
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _PATH

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
@@ -75,15 +74,6 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None:
self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
return self.checkpoint_io.save_checkpoint(checkpoint, filepath)

def teardown(self) -> None:
# TPU teardown
os.environ.pop("PT_XLA_DEBUG", None)
16 changes: 11 additions & 5 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -312,7 +312,17 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
return self.checkpoint_io.save_checkpoint(checkpoint, filepath)
# `xla_model.save` needs to be called on all ranks. It internally checks if the local rank is 0
self.checkpoint_io.save_checkpoint(checkpoint, filepath)

def remove_checkpoint(self, filepath: _PATH) -> None:
"""Remove checkpoint filepath from the filesystem.

Args:
filepath: Path to checkpoint
"""
if self.local_rank == 0:
self.checkpoint_io.remove_checkpoint(filepath)

def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""
@@ -331,10 +341,6 @@ def teardown(self) -> None:
def teardown(self) -> None:
os.environ.pop("PT_XLA_DEBUG", None)

@property
def should_rank_save_checkpoint(self) -> bool:
return self.local_rank == 0

@classmethod
def register_plugins(cls, plugin_registry: Dict) -> None:
plugin_registry.register("tpu_spawn_debug", cls, description="TPUSpawn Plugin with `debug` as True", debug=True)
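For context, the inline comment above is the key to this restructuring: the XLA save is a collective call, so the rank guard cannot stay in the plugin. A minimal sketch of the pattern (not the actual `XLACheckpointIO` implementation; it assumes a working `torch_xla` installation):

```python
import os

import torch_xla.core.xla_model as xm  # requires a TPU/torch_xla environment


def save_on_tpu(checkpoint: dict, filepath: str) -> None:
    # Every process must reach this call: `xm.save` performs a rendezvous and
    # lets only the master ordinal actually write the file. Guarding it with
    # `local_rank == 0` in the plugin would leave the other ranks waiting.
    xm.save(checkpoint, filepath)


def remove_on_tpu(filepath: str, local_rank: int) -> None:
    # Removal is a plain filesystem operation with no collective semantics,
    # so a single rank performs it, matching the new `remove_checkpoint` above.
    if local_rank == 0:
        os.remove(filepath)
```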
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -397,17 +397,17 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
if self.should_rank_save_checkpoint:
return self.checkpoint_io.save_checkpoint(checkpoint, filepath)
if self.is_global_zero:
self.checkpoint_io.save_checkpoint(checkpoint, filepath)

def remove_checkpoint(self, filepath: _PATH) -> None:
"""Remove checkpoint filepath from the filesystem.

Args:
filepath: Path to checkpoint
"""
if self.should_rank_save_checkpoint:
return self.checkpoint_io.remove_checkpoint(filepath)
if self.is_global_zero:
self.checkpoint_io.remove_checkpoint(filepath)

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
@@ -430,11 +430,6 @@ def teardown(self) -> None:
def register_plugins(cls, plugin_registry) -> None:
pass

@property
def should_rank_save_checkpoint(self) -> bool:
"""Returns whether the checkpoint should be saved (rank based)"""
return self.is_global_zero

def on_train_start(self) -> None:
"""Called when train begins."""
pass
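The net effect on the base class is that rank handling now lives in the checkpoint methods themselves instead of a separate `should_rank_save_checkpoint` hook. A simplified, self-contained sketch of the resulting pattern (names shortened, not the real class hierarchy):

```python
from typing import Any, Dict


class BasePlugin:
    """Stand-in for TrainingTypePlugin after this change."""

    is_global_zero = True  # the real plugin derives this from the global rank

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # Default behaviour: only the global-zero process touches the filesystem.
        if self.is_global_zero:
            print(f"rank 0 writes {filepath}")

    def remove_checkpoint(self, filepath: str) -> None:
        if self.is_global_zero:
            print(f"rank 0 removes {filepath}")


class TPUSpawnLikePlugin(BasePlugin):
    """Plugins with different rank semantics override the methods directly."""

    local_rank = 0

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # All ranks participate; the XLA save decides which one writes.
        print(f"rank {self.local_rank} joins the collective save of {filepath}")


TPUSpawnLikePlugin().save_checkpoint({"state_dict": {}}, "model.ckpt")
```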
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -46,6 +46,7 @@
ParallelPlugin,
PLUGIN_INPUT,
PrecisionPlugin,
TPUSpawnPlugin,
TrainingTypePlugin,
)
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
@@ -1702,9 +1703,10 @@ def world_size(self) -> int:
@property
def should_rank_save_checkpoint(self) -> bool:
rank_zero_deprecation(
"`Trainer.should_rank_save_checkpoint` is deprecated in v1.6 and will be removed in 1.8.", stacklevel=5
"`Trainer.should_rank_save_checkpoint` is deprecated in v1.6 and will be removed in v1.8.", stacklevel=5
)
return self.training_type_plugin.should_rank_save_checkpoint
ttp = self.training_type_plugin
return (isinstance(ttp, TPUSpawnPlugin) and ttp.local_rank == 0) or ttp.is_global_zero

@property
def _distrib_type(self) -> _StrategyType:
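For user code, the deprecated property keeps working until v1.8 but emits a warning. A hedged migration sketch, assuming `Trainer.is_global_zero` remains the replacement for the common (non-TPU-spawn) case:

```python
from pytorch_lightning import Trainer

trainer = Trainer(fast_dev_run=True)

# Old (emits a DeprecationWarning, removed in v1.8):
# should_save = trainer.should_rank_save_checkpoint

# New: check global rank zero directly; TPU spawn plugins now handle their own
# rank logic inside `save_checkpoint` itself.
should_save = trainer.is_global_zero
print(should_save)
```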
2 changes: 1 addition & 1 deletion tests/deprecated_api/test_remove_1-8.py
@@ -111,6 +111,6 @@ def on_hpc_load(self):
def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir):
trainer = Trainer()
with pytest.deprecated_call(
match=r"`Trainer.should_rank_save_checkpoint` is deprecated in v1.6 and will be removed in 1.8."
match=r"`Trainer.should_rank_save_checkpoint` is deprecated in v1.6 and will be removed in v1.8."
):
_ = trainer.should_rank_save_checkpoint