diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 29517ad306ebae..3b5a2c097b4e9e 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -31,7 +31,7 @@
     _process_training_step_output,
     check_finite_loss,
 )
-from pytorch_lightning.plugins import ParallelPlugin
+from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
 from pytorch_lightning.trainer.progress import OptimizationProgress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
@@ -430,9 +430,10 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator
 
         Returns: context manager with sync behaviour off
         """
-        if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and (
-            self.trainer.lightning_module.automatic_optimization or should_block_sync
-        ):
+        if (
+            isinstance(self.trainer.training_type_plugin, DDPPlugin)
+            or isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin)
+        ) and (self.trainer.lightning_module.automatic_optimization or should_block_sync):
             with self.trainer.training_type_plugin.block_backward_sync():
                 yield None
         else:
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 787353be307e69..3d72fa66b48e3a 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -19,6 +19,7 @@
 import sys
 import tempfile
 import time
+from contextlib import contextmanager
 from pathlib import Path
 from time import sleep
 from typing import Any, Dict, List, Optional, Union
@@ -442,3 +443,16 @@ def reconciliate_processes(self, trace: str):
                 os.kill(pid, signal.SIGKILL)
         shutil.rmtree(sync_dir)
         raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
+
+    @contextmanager
+    def block_backward_sync(self):
+        """
+        Blocks ddp sync gradients behaviour on backwards pass.
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        if isinstance(self.model, DistributedDataParallel):
+            with self.model.no_sync():
+                yield None
+        else:
+            yield None
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 08c049997bdfda..c7c71a87545a29 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -14,6 +14,7 @@
 import logging
 import os
 import re
+from contextlib import contextmanager
 from multiprocessing.queues import SimpleQueue
 from typing import Any, Dict, List, Optional, Union
 
@@ -364,3 +365,16 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
             description="DDPSpawn Plugin with `find_unused_parameters` as False",
             find_unused_parameters=False,
         )
+
+    @contextmanager
+    def block_backward_sync(self):
+        """
+        Blocks ddp sync gradients behaviour on backwards pass.
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        if isinstance(self.model, DistributedDataParallel):
+            with self.model.no_sync():
+                yield None
+        else:
+            yield None
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index 71aae1bb71a918..8d04882a5ff182 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 from abc import ABC, abstractmethod
-from contextlib import contextmanager
 from typing import Any, List, Optional
 
 import torch
@@ -121,19 +120,6 @@ def configure_sync_batchnorm(model: "pl.LightningModule") -> "pl.LightningModule
         """
         return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
-    @contextmanager
-    def block_backward_sync(self):
-        """
-        Blocks ddp sync gradients behaviour on backwards pass.
-        This is useful for skipping sync when accumulating gradients, reducing communication overhead
-        Returns: context manager with sync behaviour off
-        """
-        if isinstance(self.model, DistributedDataParallel):
-            with self.model.no_sync():
-                yield None
-        else:
-            yield None
-
     def teardown(self) -> None:
         # Un-reference the wrapper if any was used.
         # todo (tchaton): Add support for all plugins.
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 19ccf3935a168f..d017251e1dafc5 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -32,7 +32,7 @@
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
 from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
 from pytorch_lightning.loops.fit_loop import FitLoop
-from pytorch_lightning.plugins import DDPSpawnPlugin, Plugin
+from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin, Plugin
 from pytorch_lightning.plugins.environments import ClusterEnvironment
 from pytorch_lightning.profiler import (
     AdvancedProfiler,