diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d7e99d7163fb..34106daedefaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -205,6 +205,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052)) +- Removed `block_backward_sync` from `ParallelPlugin` and added to `DDPPlugin` and `DDPSpawnPlugin` ([#9101](https://github.com/PyTorchLightning/pytorch-lightning/pull/9101)) + + ### Fixed - Fixed save/load/resume from checkpoint for DeepSpeed Plugin ( diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 29517ad306eba..3b5a2c097b4e9 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -31,7 +31,7 @@ _process_training_step_output, check_finite_loss, ) -from pytorch_lightning.plugins import ParallelPlugin +from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm @@ -430,9 +430,10 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator Returns: context manager with sync behaviour off """ - if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and ( - self.trainer.lightning_module.automatic_optimization or should_block_sync - ): + if ( + isinstance(self.trainer.training_type_plugin, DDPPlugin) + or isinstance(self.trainer.training_type_plugin, DDPPlugin) + ) and (self.trainer.lightning_module.automatic_optimization or should_block_sync): with self.trainer.training_type_plugin.block_backward_sync(): yield None else: diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 787353be307e6..3d72fa66b48e3 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -19,6 +19,7 @@ import sys import tempfile import time +from contextlib import contextmanager from pathlib import Path from time import sleep from typing import Any, Dict, List, Optional, Union @@ -442,3 +443,16 @@ def reconciliate_processes(self, trace: str): os.kill(pid, signal.SIGKILL) shutil.rmtree(sync_dir) raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}") + + @contextmanager + def block_backward_sync(self): + """ + Blocks ddp sync gradients behaviour on backwards pass. + This is useful for skipping sync when accumulating gradients, reducing communication overhead + Returns: context manager with sync behaviour off + """ + if isinstance(self.model, DistributedDataParallel): + with self.model.no_sync(): + yield None + else: + yield None diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 08c049997bdfd..c7c71a87545a2 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -14,6 +14,7 @@ import logging import os import re +from contextlib import contextmanager from multiprocessing.queues import SimpleQueue from typing import Any, Dict, List, Optional, Union @@ -364,3 +365,16 @@ def register_plugins(cls, plugin_registry: Dict) -> None: description="DDPSpawn Plugin with `find_unused_parameters` as False", find_unused_parameters=False, ) + + @contextmanager + def block_backward_sync(self): + """ + Blocks ddp sync gradients behaviour on backwards pass. + This is useful for skipping sync when accumulating gradients, reducing communication overhead + Returns: context manager with sync behaviour off + """ + if isinstance(self.model, DistributedDataParallel): + with self.model.no_sync(): + yield None + else: + yield None diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 71aae1bb71a91..8d04882a5ff18 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -13,7 +13,6 @@ # limitations under the License. import os from abc import ABC, abstractmethod -from contextlib import contextmanager from typing import Any, List, Optional import torch @@ -121,19 +120,6 @@ def configure_sync_batchnorm(model: "pl.LightningModule") -> "pl.LightningModule """ return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - @contextmanager - def block_backward_sync(self): - """ - Blocks ddp sync gradients behaviour on backwards pass. - This is useful for skipping sync when accumulating gradients, reducing communication overhead - Returns: context manager with sync behaviour off - """ - if isinstance(self.model, DistributedDataParallel): - with self.model.no_sync(): - yield None - else: - yield None - def teardown(self) -> None: # Un-reference the wrapper if any was used. # todo (tchaton): Add support for all plugins.