diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1b12b37fdb8fb..c4cf25d533dea 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,12 +22,21 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))
 
 
+- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))
+
+
+- Changed `resolve_training_type_plugin` to allow setting `num_nodes` and `sync_batchnorm` from the `Trainer` settings ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
+
+
 ### Deprecated
 
 - Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))
 
 
+- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
+
+
 ### Removed
 
 
@@ -144,9 +153,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Ensure accelerator is valid if running interactively ([#5970](https://github.com/PyTorchLightning/pytorch-lightning/pull/5970))
 - Disabled batch transfer in DP mode ([#6098](https://github.com/PyTorchLightning/pytorch-lightning/pull/6098))
 
-- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))
-
-
 ### Deprecated
 
 - Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 63f019dfe8048..89e714d57f870 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -33,6 +33,7 @@
     _HYDRA_AVAILABLE,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
+    rank_zero_deprecation,
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
@@ -62,9 +63,9 @@ class DDPPlugin(ParallelPlugin):
     def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: int = 1,
+        num_nodes: Optional[int] = None,
         cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm: bool = False,
+        sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
@@ -72,13 +73,23 @@ def __init__(
     ) -> None:
         super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
         self.interactive_ddp_procs = []
-        self.num_nodes = num_nodes
-        self.sync_batchnorm = sync_batchnorm
+        if num_nodes is not None:
+            rank_zero_deprecation(
+                "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
+                " Notice that it will be overridden by the trainer setting."
+            )
+        self._num_nodes = num_nodes or 1
+        if sync_batchnorm is not None:
+            rank_zero_deprecation(
+                "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
+                " Notice that it will be overridden by the trainer setting."
+            )
+        self._sync_batchnorm = sync_batchnorm or False
         self.dist = LightningDistributed()
+        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._has_spawned_children = False
         self.task_idx = None
-        self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
@@ -88,6 +99,24 @@ def __init__(
     def root_device(self):
         return self.parallel_devices[self.local_rank]
 
+    @property
+    def num_nodes(self) -> int:
+        return self._num_nodes
+
+    @num_nodes.setter
+    def num_nodes(self, num_nodes: int) -> None:
+        # the world ranks depend on num_nodes, so they need to be recomputed whenever num_nodes is reset
+        self._num_nodes = num_nodes
+        self.set_world_ranks()
+
+    @property
+    def sync_batchnorm(self) -> bool:
+        return self._sync_batchnorm
+
+    @sync_batchnorm.setter
+    def sync_batchnorm(self, sync_batchnorm: bool) -> None:
+        self._sync_batchnorm = sync_batchnorm
+
     @property
     def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
@@ -212,10 +241,11 @@ def _check_can_spawn_children(self):
             )
 
     def set_world_ranks(self) -> None:
-        if self.cluster_environment is not None:
-            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-            rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
 
     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py
index fe01835b4641d..b6d21904d1933 100644
--- a/pytorch_lightning/plugins/training_type/ddp2.py
+++ b/pytorch_lightning/plugins/training_type/ddp2.py
@@ -72,6 +72,8 @@ def distributed_sampler_kwargs(self):
     def _is_single_process_single_device(self) -> bool:
         return False
 
-    def set_world_ranks(self):
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
         self.cluster_environment.set_global_rank(self.node_rank)
         self.cluster_environment.set_world_size(self.num_nodes)
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index deea818db9f75..df9f0ee158ba3 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -31,7 +31,13 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    rank_zero_deprecation,
+    rank_zero_only,
+    rank_zero_warn,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.seed import reset_seed
 
 if _TORCH_GREATER_EQUAL_1_8:
@@ -51,17 +57,27 @@ class DDPSpawnPlugin(ParallelPlugin):
     def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: int = 1,
+        num_nodes: Optional[int] = None,
         cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm: bool = False,
+        sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
         **kwargs: Any,
     ):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
-        self.num_nodes = num_nodes
-        self.sync_batchnorm = sync_batchnorm
+        if num_nodes is not None:
+            rank_zero_deprecation(
+                "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. "
+                "Notice that it will be overridden by the trainer setting."
+            )
+        self._num_nodes = num_nodes or 1
+        if sync_batchnorm is not None:
+            rank_zero_deprecation(
+                "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. "
+                "Notice that it will be overridden by the trainer setting."
+            )
+        self._sync_batchnorm = sync_batchnorm or False
         self._ddp_kwargs = kwargs
         self.dist = LightningDistributed()
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
@@ -72,6 +88,24 @@ def __init__(
         self._local_rank = 0
         self.set_world_ranks()
 
+    @property
+    def num_nodes(self) -> int:
+        return self._num_nodes
+
+    @num_nodes.setter
+    def num_nodes(self, num_nodes: int) -> None:
+        # the world ranks depend on num_nodes, so they need to be recomputed whenever num_nodes is reset
+        self._num_nodes = num_nodes
+        self.set_world_ranks()
+
+    @property
+    def sync_batchnorm(self) -> bool:
+        return self._sync_batchnorm
+
+    @sync_batchnorm.setter
+    def sync_batchnorm(self, sync_batchnorm: bool) -> None:
+        self._sync_batchnorm = sync_batchnorm
+
     @property
     def local_rank(self) -> int:
         return self._local_rank
@@ -106,10 +140,11 @@ def setup(self, model):
 
     def set_world_ranks(self, process_idx: int = 0) -> None:
         self._local_rank = process_idx
-        if self.cluster_environment is not None:
-            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-            rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
 
     @property
     def mp_spawn_kwargs(self):
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index fe3f51fa99390..8dd04aafa6b86 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -91,7 +91,7 @@ def __init__(
         logging_batch_size_per_gpu: Union[str, int] = "auto",
         config: Optional[Union[Path, str, dict]] = None,
         logging_level: int = logging.WARN,
-        num_nodes: int = 1,
+        num_nodes: Optional[int] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         loss_scale: float = 0,
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index a8a72c1831600..8f25458922ffe 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -133,6 +133,7 @@ def __init__(
 
         self.handle_given_plugins()
+        self._training_type_plugin_resolved = False
 
         self.accelerator = self.select_accelerator()
 
         # override dist backend when using tpus
@@ -222,10 +223,13 @@ def precision_plugin(self) -> PrecisionPlugin:
 
     @property
     def training_type_plugin(self) -> TrainingTypePlugin:
+        if self._training_type_plugin_resolved:
+            # avoid calling `resolve_training_type_plugin` multiple times
+            return self._training_type_plugin
         if self._training_type_plugin is None:
             self._training_type_plugin = self.select_training_type_plugin()
-        else:
-            self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
+        self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
+        self._training_type_plugin_resolved = True
 
         return self._training_type_plugin
 
@@ -320,7 +324,6 @@ def is_using_torchelastic(self) -> bool:
         """
         .. deprecated:: v1.3
             Will be removed in v1.5.0.
-
         Returns:
             ``True`` if the current process was launched using the torchelastic command.
         """
@@ -385,15 +388,11 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
         if self.use_ddp2:
             plugin = DDP2Plugin(
                 parallel_devices=self.parallel_devices,
-                num_nodes=self.num_nodes,
                 cluster_environment=self.cluster_environment,
-                sync_batchnorm=self.sync_batchnorm,
             )
         elif self.use_ddp and self.use_deepspeed:
             plugin = DeepSpeedPlugin(
-                num_nodes=self.num_nodes,
-                cluster_environment=self.select_cluster_environment(),
-                parallel_devices=self.parallel_devices
+                cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
             )
         elif self.use_ddp:
             use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
@@ -426,9 +425,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
 
             plugin = ddp_plugin_cls(
                 parallel_devices=self.parallel_devices,
-                num_nodes=self.num_nodes,
                 cluster_environment=self.cluster_environment,
-                sync_batchnorm=self.sync_batchnorm,
             )
         elif self.use_dp:
             plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
@@ -443,7 +440,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
 
     def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
         # necessary for when the user has passed in a plugin
-        if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'):
+        if hasattr(training_type, 'parallel_devices') and getattr(training_type, 'parallel_devices') is None:
             training_type.parallel_devices = self.parallel_devices
             if hasattr(training_type, 'num_processes'):
                 training_type.num_processes = len(self.parallel_devices)
@@ -451,12 +448,12 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra
         if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None:
             training_type.cluster_environment = self.select_cluster_environment()
 
-        if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None:
+        if hasattr(training_type, 'num_nodes'):
+            # set num_nodes for training_type from trainer setting
             training_type.num_nodes = self.num_nodes
 
-        # Automatically set sync_batchnorm if None.
-        # Useful for custom plugins.
-        if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None:
+        if hasattr(training_type, 'sync_batchnorm'):
+            # set sync_batchnorm for training_type from trainer setting
             training_type.sync_batchnorm = self.sync_batchnorm
 
         return training_type
diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py
index 09312a4c41963..6949175d7df14 100644
--- a/tests/deprecated_api/test_remove_1-6.py
+++ b/tests/deprecated_api/test_remove_1-6.py
@@ -16,6 +16,7 @@
 import pytest
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
 from tests.helpers import BoringModel
 
 
@@ -28,3 +29,23 @@ def test_v1_6_0_trainer_model_hook_mixin(tmpdir):
 
     with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"):
         trainer.has_arg("training_step", "batch")
+
+
+def test_v1_6_0_ddp_num_nodes():
+    with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"):
+        DDPPlugin(num_nodes=1)
+
+
+def test_v1_6_0_ddp_sync_batchnorm():
+    with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
+        DDPPlugin(sync_batchnorm=False)
+
+
+def test_v1_6_0_ddp_spawn_num_nodes():
+    with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"):
+        DDPSpawnPlugin(num_nodes=1)
+
+
+def test_v1_6_0_ddp_spawn_sync_batchnorm():
+    with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
+        DDPSpawnPlugin(sync_batchnorm=False)
diff --git a/tests/plugins/test_cluster_integration.py b/tests/plugins/test_cluster_integration.py
index fda6b67ad5b84..f9ca8c23d34d9 100644
--- a/tests/plugins/test_cluster_integration.py
+++ b/tests/plugins/test_cluster_integration.py
@@ -60,13 +60,14 @@ def environment_combinations():
 
 
 @pytest.mark.parametrize(
-    "plugin_cls", [
+    "plugin_cls",
+    [
         DDPPlugin,
         DDPShardedPlugin,
         DDP2Plugin,
         pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)),
         pytest.param(RPCSequentialPlugin, marks=RunIf(fairscale_pipe=True)),
-    ]
+    ],
 )
 def test_ranks_available_manual_plugin_selection(plugin_cls):
     """ Test that the rank information is readily available after Trainer initialization. """
@@ -79,10 +80,12 @@ def test_ranks_available_manual_plugin_selection(plugin_cls):
     with mock.patch.dict(os.environ, variables):
         plugin = plugin_cls(
             parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)],
-            num_nodes=num_nodes,
             cluster_environment=cluster,
         )
-        trainer = Trainer(plugins=[plugin])
+        trainer = Trainer(
+            plugins=[plugin],
+            num_nodes=num_nodes,
+        )
         assert rank_zero_only.rank == expected["global_rank"]
         assert trainer.global_rank == expected["global_rank"]
         assert trainer.local_rank == expected["local_rank"]
@@ -91,13 +94,14 @@ def test_ranks_available_manual_plugin_selection(plugin_cls):
 
 
 @pytest.mark.parametrize(
-    "trainer_kwargs", [
+    "trainer_kwargs",
+    [
        dict(accelerator="ddp", gpus=[1, 2]),
         dict(accelerator="ddp_sharded", gpus=[1, 2]),
         dict(accelerator="ddp2", gpus=[1, 2]),
         dict(accelerator="ddp_cpu", num_processes=2),
         dict(accelerator="ddp_spawn", gpus=[1, 2]),
-    ]
+    ],
 )
 @mock.patch("torch.cuda.is_available", return_value=True)
 @mock.patch("torch.cuda.device_count", return_value=4)
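
A minimal usage sketch of the migration this patch encourages, assuming PyTorch Lightning at the state of this branch (v1.4 dev); the exact Trainer flags shown (gpus, num_nodes, sync_batchnorm) are illustrative and require a matching multi-GPU/multi-node setup to actually run:

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin

    # Deprecated: configuring the distributed topology on the plugin itself.
    # This now emits "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4 ..."
    trainer = Trainer(gpus=2, plugins=[DDPPlugin(num_nodes=2, sync_batchnorm=True)])

    # Preferred: let the Trainer own `num_nodes` and `sync_batchnorm`; the accelerator
    # connector copies them onto the plugin in `resolve_training_type_plugin`.
    trainer = Trainer(gpus=2, num_nodes=2, sync_batchnorm=True, plugins=[DDPPlugin()])

In both cases the plugin ends up with the Trainer values, since `resolve_training_type_plugin` now sets `num_nodes` and `sync_batchnorm` on any plugin that exposes those attributes, and the new `num_nodes` setter recomputes the world ranks when the value changes.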