Move block_backward_sync from ParallelPlugin to DDPPlugins
four4fish committed Aug 25, 2021
1 parent de57fef commit cc6284a
Showing 5 changed files with 36 additions and 18 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -205,6 +205,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052))


- Removed `block_backward_sync` from `ParallelPlugin` and added to `DDPPlugin` and `DDPSpawnPlugin` ([#9101](https://github.com/PyTorchLightning/pytorch-lightning/pull/9101))


### Fixed

- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
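Since `block_backward_sync` is no longer defined on `ParallelPlugin`, code outside the loops needs a DDP-specific type check before using it. A minimal sketch of such a check (the `maybe_block_sync` helper is hypothetical and not part of this commit):

```python
import contextlib

from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin


def maybe_block_sync(training_type_plugin):
    """Return the plugin's no-sync context manager when available, else a no-op."""
    if isinstance(training_type_plugin, (DDPPlugin, DDPSpawnPlugin)):
        # Only the DDP-based plugins expose `block_backward_sync` after this change.
        return training_type_plugin.block_backward_sync()
    return contextlib.nullcontext()
```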
9 changes: 5 additions & 4 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -31,7 +31,7 @@
_process_training_step_output,
check_finite_loss,
)
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
@@ -430,9 +430,10 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator
Returns:
context manager with sync behaviour off
"""
if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and (
self.trainer.lightning_module.automatic_optimization or should_block_sync
):
if (
    isinstance(self.trainer.training_type_plugin, DDPPlugin)
    or isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin)
) and (self.trainer.lightning_module.automatic_optimization or should_block_sync):
with self.trainer.training_type_plugin.block_backward_sync():
yield None
else:
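For context, a hedged sketch of how the loop method above is entered during gradient accumulation (`loop` stands for this `TrainingBatchLoop` instance and `training_step_and_backward_closure` is a placeholder for the step closure, both assumptions here):

```python
# With a DDPPlugin or DDPSpawnPlugin and automatic optimization enabled, the
# context delegates to the plugin's block_backward_sync(), i.e. DDP's
# no_sync(), so the backward pass inside skips the gradient all-reduce.
with loop.block_ddp_sync_behaviour(should_block_sync=True):
    training_step_and_backward_closure()  # accumulation step: no gradient sync
```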
14 changes: 14 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -19,6 +19,7 @@
import sys
import tempfile
import time
from contextlib import contextmanager
from pathlib import Path
from time import sleep
from typing import Any, Dict, List, Optional, Union
@@ -442,3 +443,16 @@ def reconciliate_processes(self, trace: str):
os.kill(pid, signal.SIGKILL)
shutil.rmtree(sync_dir)
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")

@contextmanager
def block_backward_sync(self):
"""
Blocks ddp sync gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(self.model, DistributedDataParallel):
with self.model.no_sync():
yield None
else:
yield None
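The method added above wraps PyTorch's own `DistributedDataParallel.no_sync()`. A standalone sketch of that underlying mechanism, assuming a process group has already been initialised (e.g. via `torch.distributed.init_process_group`) and treating `batches` and `accumulate_grad_batches` as placeholders:

```python
import torch
from torch.nn.parallel import DistributedDataParallel

model = DistributedDataParallel(torch.nn.Linear(32, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step, batch in enumerate(batches):
    loss = model(batch).sum()
    if (step + 1) % accumulate_grad_batches != 0:
        # Accumulation step: defer the gradient all-reduce.
        with model.no_sync():
            loss.backward()
    else:
        # Final step of the window: gradients are synchronised here.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```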
14 changes: 14 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -14,6 +14,7 @@
import logging
import os
import re
from contextlib import contextmanager
from multiprocessing.queues import SimpleQueue
from typing import Any, Dict, List, Optional, Union

@@ -364,3 +365,16 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
description="DDPSpawn Plugin with `find_unused_parameters` as False",
find_unused_parameters=False,
)

@contextmanager
def block_backward_sync(self):
"""
Blocks ddp sync gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(self.model, DistributedDataParallel):
with self.model.no_sync():
yield None
else:
yield None
14 changes: 0 additions & 14 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -13,7 +13,6 @@
# limitations under the License.
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, List, Optional

import torch
@@ -121,19 +120,6 @@ def configure_sync_batchnorm(model: "pl.LightningModule") -> "pl.LightningModule
"""
return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

@contextmanager
def block_backward_sync(self):
"""
Blocks ddp sync gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(self.model, DistributedDataParallel):
with self.model.no_sync():
yield None
else:
yield None

def teardown(self) -> None:
# Un-reference the wrapper if any was used.
# todo (tchaton): Add support for all plugins.
