Skip to content

Commit

Permalink
Mark SLURM detection methods in AcceleratorConnector as protected (#10101)
Browse files Browse the repository at this point in the history



Co-authored-by: Justus Schock <[email protected]>
  • Loading branch information
awaelchli and justusschock authored Oct 25, 2021
1 parent 2ee3127 commit 76081fb
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 18 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))


- Deprecated access to the `AcceleratorConnector.is_slurm_managing_tasks` attribute and marked it as protected ([#10101](https://github.com/PyTorchLightning/pytorch-lightning/pull/10101))


- Deprecated access to the `AcceleratorConnector.configure_slurm_ddp` method and marked it as protected ([#10101](https://github.com/PyTorchLightning/pytorch-lightning/pull/10101))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
Expand Down
42 changes: 31 additions & 11 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(
self.precision = precision
self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
self.amp_level = amp_level
self.is_slurm_managing_tasks = False
self._is_slurm_managing_tasks = False

self._precision_plugin: Optional[PrecisionPlugin] = None
self._training_type_plugin: Optional[TrainingTypePlugin] = None
Expand Down Expand Up @@ -164,7 +164,7 @@ def __init__(
self._set_training_type_plugin()
else:
self.set_distributed_mode()
self.configure_slurm_ddp()
self._configure_slurm_ddp()

self.handle_given_plugins()
self.update_device_type_if_ipu_plugin()
Expand Down Expand Up @@ -685,15 +685,15 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
)
elif self.use_ddp:
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks
use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
use_ddp_cpu_spawn = use_ddp_spawn and self.use_cpu
use_tpu_spawn = self.use_tpu and self._distrib_type == DistributedType.TPU_SPAWN
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks
use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
use_ddp_fully_sharded = self._distrib_type == DistributedType.DDP_FULLY_SHARDED
Expand Down Expand Up @@ -789,7 +789,7 @@ def select_accelerator(self) -> Accelerator:
def select_cluster_environment(self) -> ClusterEnvironment:
if self._cluster_environment is not None:
return self._cluster_environment
if self.is_slurm_managing_tasks:
if self._is_slurm_managing_tasks:
env = SLURMEnvironment()
elif TorchElasticEnvironment.is_using_torchelastic():
env = TorchElasticEnvironment()
Expand Down Expand Up @@ -972,7 +972,27 @@ def update_device_type_if_training_type_plugin_passed(self) -> None:
elif self.has_gpu:
self._device_type = DeviceType.GPU

def configure_slurm_ddp(self):
@property
def is_slurm_managing_tasks(self) -> bool:
rank_zero_deprecation(
"`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5 and will be removed in v1.6."
)
return self._is_slurm_managing_tasks

@is_slurm_managing_tasks.setter
def is_slurm_managing_tasks(self, value: bool) -> bool:
rank_zero_deprecation(
"`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5 and will be removed in v1.6."
)
self._is_slurm_managing_tasks = value

def configure_slurm_ddp(self) -> None:
rank_zero_deprecation(
"`AcceleratorConnector.configure_slurm_ddp()` was deprecated in v1.5 and will be removed in v1.6."
)
self._configure_slurm_ddp()

def _configure_slurm_ddp(self):
# extract SLURM flag vars
# whenever we have the correct number of tasks, we let slurm manage processes
# otherwise we launch the required number of processes
Expand All @@ -981,21 +1001,21 @@ def configure_slurm_ddp(self):
num_slurm_tasks = 0
try:
num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
self.is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus
self._is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus

# enable slurm cpu
if num_requested_gpus == 0:
self.is_slurm_managing_tasks = num_slurm_tasks == self.num_processes
self._is_slurm_managing_tasks = num_slurm_tasks == self.num_processes

# in interactive mode we don't manage tasks
job_name = os.environ["SLURM_JOB_NAME"]
if job_name == "bash":
self.is_slurm_managing_tasks = False
self._is_slurm_managing_tasks = False

except Exception:
# likely not on slurm, so set the slurm managed flag to false
self.is_slurm_managing_tasks = False
self._is_slurm_managing_tasks = False

# notify user the that slurm is managing tasks
if self.is_slurm_managing_tasks:
if self._is_slurm_managing_tasks:
rank_zero_info("Multi-processing is handled by Slurm.")
12 changes: 6 additions & 6 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
def test_accelerator_choice_ddp_slurm(setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
Expand Down Expand Up @@ -132,7 +132,7 @@ def on_fit_start(self, trainer, pl_module):
def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDP2Plugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
Expand Down Expand Up @@ -307,7 +307,7 @@ def on_fit_start(self, trainer, pl_module):
def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
Expand Down Expand Up @@ -756,7 +756,7 @@ def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock):
def test_strategy_choice_ddp_slurm(setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
Expand Down Expand Up @@ -788,7 +788,7 @@ def on_fit_start(self, trainer, pl_module):
def test_strategy_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDP2Plugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
Expand Down Expand Up @@ -963,7 +963,7 @@ def on_fit_start(self, trainer, pl_module):
def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
Expand Down
15 changes: 15 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,18 @@ def test_v1_6_0_deprecated_accelerator_pass_through_functions():

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_train_batch_start(batch=None, batch_idx=0)


def test_v1_6_0_configure_slurm_ddp():
    """Calling the public ``configure_slurm_ddp()`` shim must raise a deprecation warning."""
    connector = Trainer().accelerator_connector
    with pytest.deprecated_call(match=r"`AcceleratorConnector.configure_slurm_ddp\(\)` was deprecated in v1.5"):
        connector.configure_slurm_ddp()


def test_v1_6_0_is_slurm_managing_tasks():
    """Both reading and assigning ``is_slurm_managing_tasks`` must raise a deprecation warning."""
    connector = Trainer().accelerator_connector
    pattern = r"`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5"

    # Getter path.
    with pytest.deprecated_call(match=pattern):
        _ = connector.is_slurm_managing_tasks

    # Setter path.
    with pytest.deprecated_call(match=pattern):
        connector.is_slurm_managing_tasks = False
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def test_dp_resume(tmpdir):

# fit model
trainer = Trainer(**trainer_options)
trainer.is_slurm_managing_tasks = True
trainer._is_slurm_managing_tasks = True
trainer.fit(model, datamodule=dm)

# track epoch before saving. Increment since we finished the current epoch, don't want to rerun
Expand Down

0 comments on commit 76081fb

Please sign in to comment.