From 97a81c3cfed9d6677d411672eecb0ed38514cb04 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Sat, 20 Feb 2021 12:58:54 +0000
Subject: [PATCH] [Hot Fix] Give priority to plugins to set distributed mode, and then accelerator (#6089)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Give priority to plugins to set distributed mode, and then accelerator

* Add CHANGELOG.md

* Update CHANGELOG.md

* Remove very scary line

* Ensure we set cluster environment after slurm configured if necessary

* Simplify the fix with a reset

Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                                  |  2 ++
 .../connectors/accelerator_connector.py       |  4 +++-
 .../test_accelerator_connector.py             | 24 ++++++++++++++++++-
 3 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 66a64d52195a7..52cdb000a2a0f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)
 
+- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
+
 
 ## [1.2.0] - 2021-02-18
 
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index d32970d61fa9b..7021081d6cc90 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -163,6 +163,9 @@ def handle_given_plugins(
 
         for plug in plugins:
             if isinstance(plug, str):
+                # Reset the distributed type as the user has overridden training type
+                # via the plugins argument
+                self._distrib_type = None
                 self.set_distributed_mode(plug)
 
             elif isinstance(plug, TrainingTypePlugin):
@@ -196,7 +199,6 @@ def handle_given_plugins(
             )
 
         self._training_type_plugin = training_type
-        self._training_type_plugin = self.training_type_plugin
         self._precision_plugin = precision
         self._cluster_environment = cluster_environment or self.select_cluster_environment()
 
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 76d4a597d8ecb..82b631807c8e9 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -23,7 +23,14 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin
+from pytorch_lightning.plugins import (
+    DDP2Plugin,
+    DDPPlugin,
+    DDPShardedPlugin,
+    DDPSpawnPlugin,
+    PrecisionPlugin,
+    SingleDevicePlugin,
+)
 from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
 from tests.helpers.boring_model import BoringModel
 
@@ -378,3 +385,18 @@ def on_fit_start(self, trainer, pl_module):
 
         with pytest.raises(SystemExit):
             trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    ["accelerator", "plugin"],
+    [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
+)
+def test_plugin_accelerator_choice(accelerator, plugin):
+    """
+    Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent.
+    """
+    trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
+    assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
+
+    trainer = Trainer(plugins=plugin, num_processes=2)
+    assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)