Commit 7d3a525

Merge branch 'master' into refactor/8634_duplicate_GPU
kaushikb11 authored Aug 3, 2021
2 parents 0de5b6f + e5d9e21 commit 7d3a525
Showing 28 changed files with 884 additions and 216 deletions.
2 changes: 1 addition & 1 deletion .azure-pipelines/gpu-tests.yml
@@ -51,7 +51,7 @@ jobs:
- bash: |
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
pip install fairscale>=0.3.4
pip install "deepspeed>=0.4.0, !=0.4.4" # FIXME: bug with 0.4.4
pip install "deepspeed>=0.4.3, !=0.4.4" # FIXME: bug with 0.4.4
pip install . --requirement requirements/devel.txt
pip list
displayName: 'Install dependencies'
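
The inline `python -c` filter in the step above simply strips any Horovod requirement from `requirements/extra.txt` before installing the rest. As a readability aid only (a sketch, not part of the diff), the standalone equivalent would be roughly:

```python
# Sketch of the inline filter used in the CI step above: drop any
# requirement line mentioning "horovod" before installing the rest.
fname = "requirements/extra.txt"

with open(fname) as f:
    lines = [line for line in f.readlines() if "horovod" not in line]

with open(fname, "w") as f:
    f.writelines(lines)
```
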
6 changes: 6 additions & 0 deletions .github/workflows/events-nightly.yml
@@ -13,6 +13,7 @@ env:
# based on https://github.com/pypa/gh-action-pypi-publish
jobs:
pypi-release:
if: ${{ github.repository_owner == 'PyTorchLightning' }}
runs-on: ubuntu-20.04

steps:
@@ -47,12 +48,14 @@ jobs:
verbose: true

docker-XLA:
if: ${{ github.repository_owner == 'PyTorchLightning' }}
runs-on: ubuntu-20.04
strategy:
fail-fast: false
matrix:
python_version: ["3.7"]
xla_version: ["1.6", "1.7", "1.8", "1.9"] # todo: , "nightly"

steps:
- name: Checkout
uses: actions/checkout@v2
@@ -79,6 +82,7 @@ jobs:
timeout-minutes: 55

docker-CUDA:
if: ${{ github.repository_owner == 'PyTorchLightning' }}
runs-on: ubuntu-20.04
strategy:
fail-fast: false
@@ -113,6 +117,7 @@ jobs:
timeout-minutes: 55

docker-Conda:
if: ${{ github.repository_owner == 'PyTorchLightning' }}
runs-on: ubuntu-20.04
strategy:
fail-fast: false
@@ -154,6 +159,7 @@ jobs:
timeout-minutes: 55

docker-IPU:
if: ${{ github.repository_owner == 'PyTorchLightning' }}
runs-on: ubuntu-20.04
strategy:
fail-fast: false
1 change: 1 addition & 0 deletions .github/workflows/events-recurrent.yml
@@ -15,6 +15,7 @@ env:
jobs:
tpu-cleanup:
name: TPU cleaning
if: ${{ github.repository_owner == 'PyTorchLightning' }}
runs-on: ubuntu-20.04

steps:
16 changes: 13 additions & 3 deletions CHANGELOG.md
@@ -15,11 +15,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added check for unique GPU ids ([#8666](https://github.com/PyTorchLightning/pytorch-lightning/pull/8666))


- Added `ResultCollection` state_dict to Loop `state_dict` and support for distributed reload. ([#8641](https://github.com/PyTorchLightning/pytorch-lightning/pull/8641))


-


-


- Fault-tolerant training:
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))


### Changed

- Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477))
@@ -28,7 +36,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))



- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


@@ -89,7 +96,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `training_step` outputs not getting collected correctly for `training_epoch_end` ([#8613](https://github.com/PyTorchLightning/pytorch-lightning/pull/8613))


-
- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
[#8397](https://github.com/PyTorchLightning/pytorch-lightning/pull/8397),
[#8644](https://github.com/PyTorchLightning/pytorch-lightning/pull/8644),
[#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))


## [1.4.0] - 2021-07-27
@@ -133,6 +143,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fault-tolerant training
* Added `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
* Added `{,load_}state_dict` to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197))
* Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307))
* Set `Loop.restarting=False` at the end of the first iteration ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Save the loops state with the checkpoint (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Save a checkpoint to restore the state on exception (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
@@ -155,7 +166,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `max_depth` parameter in `ModelSummary` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))
- Added `XLAStatsMonitor` callback ([#8235](https://github.com/PyTorchLightning/pytorch-lightning/pull/8235))
- Added `restore` function and `restarting` attribute to base `Loop` ([#8247](https://github.com/PyTorchLightning/pytorch-lightning/pull/8247))
- Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307))
- Added support for `save_hyperparameters` in `LightningDataModule` ([#3792](https://github.com/PyTorchLightning/pytorch-lightning/pull/3792))
- Added the `ModelCheckpoint(save_on_train_epoch_end)` to choose when to run the saving logic ([#8389](https://github.com/PyTorchLightning/pytorch-lightning/pull/8389))
- Added `LSFEnvironment` for distributed training with the LSF resource manager `jsrun` ([#5102](https://github.com/PyTorchLightning/pytorch-lightning/pull/5102))
1 change: 0 additions & 1 deletion pyproject.toml
@@ -21,7 +21,6 @@ order_by_type = "False"

[tool.black]
line-length = 120
skip-magic-trailing-comma = true


[tool.mypy]
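
For context on the option removed above (not part of the diff): Black's "magic trailing comma" treats a trailing comma in the source as a request to keep a collection one element per line; with `skip-magic-trailing-comma = true` gone, that behavior applies again. A small illustration, assuming Black's defaults otherwise:

```python
# With the magic trailing comma respected (the default, now that the
# option above was removed), Black keeps this list exploded because of
# the trailing comma after "gpu":
ACCELERATORS = [
    "cpu",
    "gpu",
]

# Without a trailing comma, Black is free to collapse it onto one line
# as long as it fits within the configured line length:
PRECISIONS = [16, 32, 64]
```
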
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -507,7 +507,7 @@ def period(self, value: Optional[int]) -> None:

def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None:
if trainer.should_rank_save_checkpoint and self._fs.exists(filepath):
self._fs.rm(filepath)
self._fs.rm(filepath, recursive=True)
log.debug(f"Removed checkpoint: {filepath}")

def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
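
The `recursive=True` change above matters when a "checkpoint" is actually a directory (for example, a DeepSpeed-style sharded checkpoint) rather than a single file. A minimal sketch of the same call against a local fsspec filesystem, with a hypothetical path:

```python
# Minimal sketch (not the Lightning implementation): with fsspec,
# rm() on a directory requires recursive=True, while a plain file
# is removed either way.
import fsspec

fs = fsspec.filesystem("file")  # local filesystem as an example

path = "/tmp/example-checkpoint"  # hypothetical path; may be a file or a directory
if fs.exists(path):
    fs.rm(path, recursive=True)  # handles both files and checkpoint directories
```
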
55 changes: 49 additions & 6 deletions pytorch_lightning/loops/base.py
@@ -16,8 +16,10 @@
from typing import Any, Dict, Optional

from deprecate import void
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress, Progress
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -173,25 +175,66 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] =
destination[prefix + "state_dict"] = self.on_save_checkpoint()

for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
destination[prefix + k] = v.state_dict()
destination[key] = v.state_dict()
elif isinstance(v, Loop):
v.state_dict(destination, prefix + k + ".")
v.state_dict(destination, key + ".")
elif isinstance(v, ResultCollection):
# sync / unsync metrics
v.sync()
destination[key] = v.state_dict()
v.unsync()

return destination

def load_state_dict(self, state_dict: Dict, prefix: str = "", restart_progress: bool = True) -> None:
def load_state_dict(
self,
state_dict: Dict,
prefix: str = "",
restart_progress: bool = True,
metrics: Optional[Dict[str, Metric]] = None,
) -> None:
"""Loads the state of this loop and all its children."""
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress)
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics)
for k, v in self.__dict__.items():
if isinstance(v, Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)

def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None:
def _load_from_state_dict(
self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None
) -> None:
for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[prefix + k])
v.load_state_dict(state_dict[key])
if restart_progress:
apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())

elif (
isinstance(v, ResultCollection)
and self.trainer is not None
and getattr(self.trainer, "lightning_module", None) is not None
):
metric_attributes = {
name: module
for name, module in self.trainer.lightning_module.named_modules()
if isinstance(module, Metric)
}
if metrics:
metric_attributes.update(metrics)

# The `ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`.
# When creating a checkpoint, the `Metric`s are dropped from the loop `state_dict` to serialize only
# Python primitives. However, their states are saved with the model's `state_dict`.
# On reload, we need to re-attach the `Metric`s back to the `ResultCollection`.
# The references are provided through the `metric_attributes` dictionary.
v.load_state_dict(
state_dict[prefix + k], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce
)

if not self.trainer.is_global_zero:
v.reset(metrics=False)

self.on_load_checkpoint(state_dict[prefix + "state_dict"])
self.restarting = True
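
The new `state_dict` above flattens progress trackers, child loops, and `ResultCollection` objects into a single dict keyed by dotted prefixes, and `load_state_dict` walks the same attributes back, re-attaching `torchmetrics.Metric` references that live in the model's `state_dict` rather than the loop's. The following is a stripped-down sketch of just the prefixing/recursion pattern, using toy classes rather than Lightning's actual `Loop` and progress types:

```python
# Stripped-down sketch of the prefixed, recursive state_dict pattern
# used above (toy classes, not Lightning's actual Loop/Progress).
from typing import Dict, Optional


class ToyProgress:
    def __init__(self) -> None:
        self.completed = 0

    def state_dict(self) -> Dict:
        return {"completed": self.completed}


class ToyLoop:
    def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
        if destination is None:
            destination = {}
        destination[prefix + "state_dict"] = {}  # placeholder for on_save_checkpoint()
        for name, value in self.__dict__.items():
            key = prefix + name
            if isinstance(value, ToyProgress):
                destination[key] = value.state_dict()
            elif isinstance(value, ToyLoop):
                # recurse with an extended prefix so keys stay unique and flat
                value.state_dict(destination, key + ".")
        return destination


class EpochLoop(ToyLoop):
    def __init__(self) -> None:
        self.progress = ToyProgress()


class FitLoop(ToyLoop):
    def __init__(self) -> None:
        self.progress = ToyProgress()
        self.epoch_loop = EpochLoop()


# FitLoop().state_dict() ->
# {'state_dict': {}, 'progress': {'completed': 0},
#  'epoch_loop.state_dict': {}, 'epoch_loop.progress': {'completed': 0}}
```
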
23 changes: 15 additions & 8 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -179,9 +179,6 @@ def _call_children_scripts(self):
os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

# create a temporary directory used to synchronize processes on deadlock.
os.environ["PL_DDP_SYNC_TMPDIR"] = self._sync_dir = tempfile.mkdtemp()

# Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
# See https://docs.python.org/3/reference/import.html#main-spec
if __main__.__spec__ is None: # pragma: no-cover
@@ -410,8 +407,18 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
def _share_information_to_prevent_deadlock(self):
self._share_pids()

# remove `PL_DDP_SYNC_TMPDIR` from os.environ
self._sync_dir = os.environ.pop("PL_DDP_SYNC_TMPDIR", None)
# there should be a unique sync_dir per node.
if self.local_rank == 0:
# create a temporary directory used to synchronize processes on deadlock.
self._sync_dir = tempfile.mkdtemp()

sync_dirs = []
global_node_rank_zero = 0
for _ in range(self.num_nodes):
sync_dirs.append(self.broadcast(self._sync_dir, global_node_rank_zero))
global_node_rank_zero += self.world_size // self.num_nodes

self._sync_dir = sync_dirs[self.node_rank]

def _share_pids(self):
"""
@@ -436,11 +443,11 @@ def reconciliate_processes(self, trace: str):

# return if all processes wrote a file in the `sync_dir`.
# todo (tchaton) Add support for non-shared file-system which will fail.
if len(os.listdir(sync_dir)) == self.world_size:
if len(os.listdir(sync_dir)) == (self.world_size // self.num_nodes):
return

for pid in self._pids:
if pid != os.getpid():
os.kill(pid, signal.SIGKILL)
shutil.rmtree(sync_dir)
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
shutil.rmtree(sync_dir)
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
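
The reworked `_share_information_to_prevent_deadlock` above creates the temporary sync directory only on each node's local rank 0 and then broadcasts the path so every rank on that node agrees on it. A rough sketch of the underlying idea with plain `torch.distributed` (assuming the process group is already initialized; Lightning itself goes through the plugin's own `broadcast` helper and repeats this once per node):

```python
# Rough sketch, assuming torch.distributed is already initialized:
# broadcast a rank-0 temp directory path to every other rank, which is
# the core of the per-node sync_dir exchange in the diff above.
import tempfile

import torch.distributed as dist


def broadcast_sync_dir(src_rank: int = 0) -> str:
    payload = [None]
    if dist.get_rank() == src_rank:
        # only the source rank actually creates the directory
        payload[0] = tempfile.mkdtemp()
    # broadcast_object_list fills `payload` in place on the other ranks
    dist.broadcast_object_list(payload, src=src_rank)
    return payload[0]
```
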