Skip to content

Commit

Permalink
Set num_nodes and sync_batchnorm From Trainer for Manually Passed…
Browse files Browse the repository at this point in the history
… Training Type Plugin (#7026)

Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
6 people authored May 8, 2021
1 parent 710b144 commit 987530c
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 44 deletions.
12 changes: 9 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,21 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))


- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))


- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


### Deprecated


- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))


- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


### Removed


Expand Down Expand Up @@ -144,9 +153,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Ensure accelerator is valid if running interactively ([#5970](https://github.com/PyTorchLightning/pytorch-lightning/pull/5970))
- Disabled batch transfer in DP mode ([#6098](https://github.com/PyTorchLightning/pytorch-lightning/pull/6098))

- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))


### Deprecated

- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))
Expand Down
48 changes: 39 additions & 9 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_HYDRA_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
rank_zero_deprecation,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
Expand Down Expand Up @@ -62,23 +63,33 @@ class DDPPlugin(ParallelPlugin):
def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: int = 1,
num_nodes: Optional[int] = None,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: bool = False,
sync_batchnorm: Optional[bool] = None,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
self.interactive_ddp_procs = []
self.num_nodes = num_nodes
self.sync_batchnorm = sync_batchnorm
if num_nodes is not None:
rank_zero_deprecation(
"Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
" Notice that it will be overriden by the trainer setting."
)
self._num_nodes = num_nodes or 1
if sync_batchnorm is not None:
rank_zero_deprecation(
"Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
" Notice that it will be overriden by the trainer setting."
)
self._sync_batchnorm = sync_batchnorm or False
self.dist = LightningDistributed()
self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
self._ddp_kwargs = kwargs
self._has_spawned_children = False
self.task_idx = None
self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
Expand All @@ -88,6 +99,24 @@ def __init__(
def root_device(self):
return self.parallel_devices[self.local_rank]

@property
def num_nodes(self) -> int:
return self._num_nodes

@num_nodes.setter
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
self._num_nodes = num_nodes
self.set_world_ranks()

@property
def sync_batchnorm(self) -> bool:
return self._sync_batchnorm

@sync_batchnorm.setter
def sync_batchnorm(self, sync_batchnorm: bool) -> None:
self._sync_batchnorm = sync_batchnorm

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
Expand Down Expand Up @@ -212,10 +241,11 @@ def _check_can_spawn_children(self):
)

def set_world_ranks(self) -> None:
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def distributed_sampler_kwargs(self):
def _is_single_process_single_device(self) -> bool:
return False

def set_world_ranks(self):
def set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank)
self.cluster_environment.set_world_size(self.num_nodes)
53 changes: 44 additions & 9 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import (
rank_zero_deprecation,
rank_zero_only,
rank_zero_warn,
ReduceOp,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.seed import reset_seed

if _TORCH_GREATER_EQUAL_1_8:
Expand All @@ -51,17 +57,27 @@ class DDPSpawnPlugin(ParallelPlugin):
def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: int = 1,
num_nodes: Optional[int] = None,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: bool = False,
sync_batchnorm: Optional[bool] = None,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
**kwargs: Any,
):
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
self.num_nodes = num_nodes
self.sync_batchnorm = sync_batchnorm
if num_nodes is not None:
rank_zero_deprecation(
"Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. "
"Notice that it will be overriden by the trainer setting."
)
self._num_nodes = num_nodes or 1
if sync_batchnorm is not None:
rank_zero_deprecation(
"Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. "
"Notice that it will be overriden by the trainer setting."
)
self._sync_batchnorm = sync_batchnorm or False
self._ddp_kwargs = kwargs
self.dist = LightningDistributed()
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
Expand All @@ -72,6 +88,24 @@ def __init__(
self._local_rank = 0
self.set_world_ranks()

@property
def num_nodes(self) -> int:
return self._num_nodes

@num_nodes.setter
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
self._num_nodes = num_nodes
self.set_world_ranks()

@property
def sync_batchnorm(self) -> bool:
return self._sync_batchnorm

@sync_batchnorm.setter
def sync_batchnorm(self, sync_batchnorm: bool) -> None:
self._sync_batchnorm = sync_batchnorm

@property
def local_rank(self) -> int:
return self._local_rank
Expand Down Expand Up @@ -106,10 +140,11 @@ def setup(self, model):

def set_world_ranks(self, process_idx: int = 0) -> None:
self._local_rank = process_idx
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

@property
def mp_spawn_kwargs(self):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
logging_batch_size_per_gpu: Union[str, int] = "auto",
config: Optional[Union[Path, str, dict]] = None,
logging_level: int = logging.WARN,
num_nodes: int = 1,
num_nodes: Optional[int] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
loss_scale: float = 0,
Expand Down
27 changes: 12 additions & 15 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(

self.handle_given_plugins()

self._training_type_plugin_resolved = False
self.accelerator = self.select_accelerator()

# override dist backend when using tpus
Expand Down Expand Up @@ -222,10 +223,13 @@ def precision_plugin(self) -> PrecisionPlugin:

@property
def training_type_plugin(self) -> TrainingTypePlugin:
if self._training_type_plugin_resolved:
# avoid calling `resolve_training_type_plugin` multiple times
return self._training_type_plugin
if self._training_type_plugin is None:
self._training_type_plugin = self.select_training_type_plugin()
else:
self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
self._training_type_plugin_resolved = True

return self._training_type_plugin

Expand Down Expand Up @@ -320,7 +324,6 @@ def is_using_torchelastic(self) -> bool:
"""
.. deprecated:: v1.3
Will be removed in v1.5.0.
Returns:
``True`` if the current process was launched using the torchelastic command.
"""
Expand Down Expand Up @@ -385,15 +388,11 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
if self.use_ddp2:
plugin = DDP2Plugin(
parallel_devices=self.parallel_devices,
num_nodes=self.num_nodes,
cluster_environment=self.cluster_environment,
sync_batchnorm=self.sync_batchnorm,
)
elif self.use_ddp and self.use_deepspeed:
plugin = DeepSpeedPlugin(
num_nodes=self.num_nodes,
cluster_environment=self.select_cluster_environment(),
parallel_devices=self.parallel_devices
cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
)
elif self.use_ddp:
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
Expand Down Expand Up @@ -426,9 +425,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:

plugin = ddp_plugin_cls(
parallel_devices=self.parallel_devices,
num_nodes=self.num_nodes,
cluster_environment=self.cluster_environment,
sync_batchnorm=self.sync_batchnorm,
)
elif self.use_dp:
plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
Expand All @@ -443,20 +440,20 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:

def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
# necessary for when the user has passed in a plugin
if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'):
if hasattr(training_type, 'parallel_devices') and getattr(training_type, 'parallel_devices') is None:
training_type.parallel_devices = self.parallel_devices
if hasattr(training_type, 'num_processes'):
training_type.num_processes = len(self.parallel_devices)

if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None:
training_type.cluster_environment = self.select_cluster_environment()

if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None:
if hasattr(training_type, 'num_nodes'):
# set num_nodes for training_type from trainer setting
training_type.num_nodes = self.num_nodes

# Automatically set sync_batchnorm if None.
# Useful for custom plugins.
if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None:
if hasattr(training_type, 'sync_batchnorm'):
# set sync_batchnorm for training_type from trainer setting
training_type.sync_batchnorm = self.sync_batchnorm

return training_type
Expand Down
21 changes: 21 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
from tests.helpers import BoringModel


Expand All @@ -28,3 +29,23 @@ def test_v1_6_0_trainer_model_hook_mixin(tmpdir):

with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"):
trainer.has_arg("training_step", "batch")


def test_v1_6_0_ddp_num_nodes():
with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"):
DDPPlugin(num_nodes=1)


def test_v1_6_0_ddp_sync_batchnorm():
with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
DDPPlugin(sync_batchnorm=False)


def test_v1_6_0_ddp_spawn_num_nodes():
with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"):
DDPSpawnPlugin(num_nodes=1)


def test_v1_6_0_ddp_spawn_sync_batchnorm():
with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
DDPSpawnPlugin(sync_batchnorm=False)
16 changes: 10 additions & 6 deletions tests/plugins/test_cluster_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ def environment_combinations():


@pytest.mark.parametrize(
"plugin_cls", [
"plugin_cls",
[
DDPPlugin,
DDPShardedPlugin,
DDP2Plugin,
pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)),
pytest.param(RPCSequentialPlugin, marks=RunIf(fairscale_pipe=True)),
]
],
)
def test_ranks_available_manual_plugin_selection(plugin_cls):
""" Test that the rank information is readily available after Trainer initialization. """
Expand All @@ -79,10 +80,12 @@ def test_ranks_available_manual_plugin_selection(plugin_cls):
with mock.patch.dict(os.environ, variables):
plugin = plugin_cls(
parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)],
num_nodes=num_nodes,
cluster_environment=cluster,
)
trainer = Trainer(plugins=[plugin])
trainer = Trainer(
plugins=[plugin],
num_nodes=num_nodes,
)
assert rank_zero_only.rank == expected["global_rank"]
assert trainer.global_rank == expected["global_rank"]
assert trainer.local_rank == expected["local_rank"]
Expand All @@ -91,13 +94,14 @@ def test_ranks_available_manual_plugin_selection(plugin_cls):


@pytest.mark.parametrize(
"trainer_kwargs", [
"trainer_kwargs",
[
dict(accelerator="ddp", gpus=[1, 2]),
dict(accelerator="ddp_sharded", gpus=[1, 2]),
dict(accelerator="ddp2", gpus=[1, 2]),
dict(accelerator="ddp_cpu", num_processes=2),
dict(accelerator="ddp_spawn", gpus=[1, 2]),
]
],
)
@mock.patch("torch.cuda.is_available", return_value=True)
@mock.patch("torch.cuda.device_count", return_value=4)
Expand Down

0 comments on commit 987530c

Please sign in to comment.