Set num_nodes and sync_batchnorm From Trainer for Manually Passed Training Type Plugin #7026

Merged
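Summary: with this change, a training type plugin passed manually through `Trainer(plugins=...)` inherits `num_nodes` and `sync_batchnorm` from the `Trainer` arguments instead of requiring them on the plugin constructor. A minimal usage sketch of what this enables (the specific plugin arguments are illustrative only, not required by the change):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# num_nodes and sync_batchnorm are configured on the Trainer only;
# the manually passed DDPPlugin picks them up when the plugin is resolved.
trainer = Trainer(
    gpus=2,
    num_nodes=2,
    sync_batchnorm=True,
    plugins=[DDPPlugin(find_unused_parameters=False)],
)
```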
Commits (75)
89f284d
Fix some test errors
Mar 23, 2021
80cfbff
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 23, 2021
536c132
checkpoint consolidation
Mar 24, 2021
f172101
Update ddp_spawn.py
shuyingsunshine21 Mar 24, 2021
bf70e43
Update test_metric_result_integration.py
shuyingsunshine21 Mar 24, 2021
ea74906
Update test_results.py
shuyingsunshine21 Mar 24, 2021
a9aae99
Update utils.py
shuyingsunshine21 Mar 24, 2021
70fe5da
Update utils.py
shuyingsunshine21 Mar 24, 2021
0d23d75
Update test_all_gather_grad.py
shuyingsunshine21 Mar 24, 2021
ca6f98b
Update test_all_gather_grad.py
shuyingsunshine21 Mar 24, 2021
c5053da
Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkp…
shuyingsunshine21 Mar 24, 2021
9d4a2b8
Update test_results.py
shuyingsunshine21 Mar 24, 2021
7635b4f
Revert "Update test_results.py"
shuyingsunshine21 Mar 24, 2021
d64f90c
Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine2…
shuyingsunshine21 Mar 24, 2021
dcdcd29
Revert "Update test_all_gather_grad.py"
shuyingsunshine21 Mar 24, 2021
8651d54
Revert "Update utils.py"
shuyingsunshine21 Mar 24, 2021
15f4b9e
Revert "Update utils.py"
shuyingsunshine21 Mar 24, 2021
250d0aa
Revert "Update test_results.py"
shuyingsunshine21 Mar 24, 2021
6c095b2
Revert "Update test_metric_result_integration.py"
shuyingsunshine21 Mar 24, 2021
8222dc9
Revert "Update ddp_spawn.py"
shuyingsunshine21 Mar 24, 2021
3a9fde9
Revert "checkpoint consolidation"
shuyingsunshine21 Mar 24, 2021
7a369f4
Revert "Revert "checkpoint consolidation""
shuyingsunshine21 Mar 24, 2021
b4a0b9e
Revert "Revert "Revert "checkpoint consolidation"""
shuyingsunshine21 Mar 24, 2021
5cf1db1
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 24, 2021
0ce7e05
Revert "Revert "Update ddp_spawn.py""
shuyingsunshine21 Mar 24, 2021
fe9736d
Revert "Revert "Update test_metric_result_integration.py""
shuyingsunshine21 Mar 24, 2021
c314ef6
Revert "Revert "Update test_results.py""
shuyingsunshine21 Mar 24, 2021
c3feda0
Revert "Revert "Update utils.py""
shuyingsunshine21 Mar 24, 2021
c759477
Revert "Revert "Update test_all_gather_grad.py""
shuyingsunshine21 Mar 24, 2021
7a8e540
Merge branch 'master' of https://github.com/shuyingsunshine21/pytorch…
Mar 24, 2021
ab8b849
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 24, 2021
4e67db2
modify distributed environment to make test pass
Mar 24, 2021
67b6188
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 25, 2021
f9afa07
rebase to upstream master
Apr 8, 2021
f337156
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 14, 2021
fffecb8
rfc
Apr 15, 2021
a74e712
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 15, 2021
089e566
rebase
Apr 15, 2021
bb8ed77
formatting
Apr 15, 2021
6b7fe6f
more nits
Apr 15, 2021
90fa8e0
nit
Apr 15, 2021
ba4f9c4
split, setting num_nodes and sync batchnorm only
Apr 15, 2021
1eed6c9
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 15, 2021
7c88c70
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 15, 2021
bdb66ab
fix test
Apr 15, 2021
552f445
add changlog
Apr 15, 2021
1655f1e
retrigger checkes
Apr 16, 2021
76853ef
Merge branch 'master' into training_type_plugin_consolidate
tchaton Apr 19, 2021
ad77ad4
comments
Apr 20, 2021
de24614
rebase
Apr 20, 2021
c9ded5b
rebase
Apr 20, 2021
77ef90a
change accelerator_connector training_type_plugin to resolve only once
Apr 20, 2021
36427ca
nits
Apr 20, 2021
eae6dc7
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 21, 2021
824fb25
make num_nodes and sync_batchnorm as optional argument for plugin and…
Apr 21, 2021
66fab62
format
Apr 21, 2021
63e4a4e
change warn to deprecation
Apr 21, 2021
2b8c772
fix
Apr 21, 2021
4feded8
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 22, 2021
6aa1cf1
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 23, 2021
76016de
minor
Apr 23, 2021
0996a5d
remove unnecessary assert
Apr 23, 2021
afa3bbd
Merge branch 'master' into training_type_plugin_consolidate
kaushikb11 Apr 26, 2021
c3b63a2
rebase
May 4, 2021
16858be
comments
May 4, 2021
e8a110b
pull rebase
May 4, 2021
60580be
remove extra in change.md
May 4, 2021
20d59a4
correct in change.md
May 4, 2021
0ab7147
fix test and flake8
May 4, 2021
9381117
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
May 4, 2021
a35fdc3
Merge branch 'master' into training_type_plugin_consolidate
carmocca May 4, 2021
9fdde94
pre-commit
carmocca May 4, 2021
6680b0d
Merge branch 'master' into training_type_plugin_consolidate
awaelchli May 8, 2021
621bfc8
whitespace standardization
awaelchli May 8, 2021
29f720b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2021
Files changed (changes from 53 commits)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -141,6 +141,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))


- Changed `resolve_training_type_plugin` to allow setting `num_nodes` and `sync_batchnorm` from the `Trainer` settings ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


- `pl.seed_everything` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024))


24 changes: 18 additions & 6 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -72,13 +72,13 @@ def __init__(
) -> None:
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
self.interactive_ddp_procs = []
self.num_nodes = num_nodes
self._num_nodes = num_nodes
self.sync_batchnorm = sync_batchnorm
self.dist = LightningDistributed()
self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
self._ddp_kwargs = kwargs
self._has_spawned_children = False
self.task_idx = None
self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
@@ -88,6 +88,17 @@ def __init__(
def root_device(self):
return self.parallel_devices[self.local_rank]

@property
def num_nodes(self) -> int:
return self._num_nodes

@num_nodes.setter
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks is related to num_nodes, when resetting these parameters,
# need to reset world ranks
self._num_nodes = num_nodes
self.set_world_ranks()

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
@@ -215,10 +226,11 @@ def _check_can_spawn_children(self):
)

def set_world_ranks(self) -> None:
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
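For context on the `num_nodes` property introduced above: the world size and global rank are derived from `num_nodes`, so reassigning it after construction has to recompute them, which is why the setter calls `set_world_ranks()`. A self-contained sketch of that pattern (a toy class, not the actual plugin code):

```python
class WorldSizeTracker:
    """Toy illustration of the num_nodes setter pattern used by the DDP plugins."""

    def __init__(self, num_nodes: int = 1, num_processes: int = 1) -> None:
        self.num_processes = num_processes
        self.world_size = 0
        # assign through the property so derived state is computed at init time too
        self.num_nodes = num_nodes

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        # world size/ranks depend on num_nodes, so recompute them on every reset
        self._num_nodes = num_nodes
        self.set_world_ranks()

    def set_world_ranks(self) -> None:
        self.world_size = self.num_nodes * self.num_processes


tracker = WorldSizeTracker(num_nodes=1, num_processes=4)
assert tracker.world_size == 4
tracker.num_nodes = 2  # what the accelerator connector does when resolving the plugin
assert tracker.world_size == 8
```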
4 changes: 3 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp2.py
@@ -72,6 +72,8 @@ def distributed_sampler_kwargs(self):
def _is_single_process_single_device(self) -> bool:
return False

def set_world_ranks(self):
def set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank)
self.cluster_environment.set_world_size(self.num_nodes)
22 changes: 17 additions & 5 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -60,7 +60,7 @@ def __init__(
**kwargs: Any,
):
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
self.num_nodes = num_nodes
self._num_nodes = num_nodes
self.sync_batchnorm = sync_batchnorm
self._ddp_kwargs = kwargs
self.dist = LightningDistributed()
@@ -72,6 +72,17 @@ def __init__(
self._local_rank = 0
self.set_world_ranks()

@property
def num_nodes(self) -> int:
return self._num_nodes

@num_nodes.setter
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks is related to num_nodes, when resetting these parameters,
# need to reset world ranks
self._num_nodes = num_nodes
self.set_world_ranks()

@property
def local_rank(self) -> int:
return self._local_rank
@@ -106,10 +117,11 @@ def setup(self, model):

def set_world_ranks(self, process_idx: int = 0) -> None:
self._local_rank = process_idx
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

@property
def mp_spawn_kwargs(self):
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/rpc.py
@@ -42,7 +42,7 @@ def __init__(
self,
rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: Optional[int] = None,
num_nodes: int = 1,
cluster_environment: Optional[ClusterEnvironment] = None,
sync_batchnorm: Optional[bool] = None,
**kwargs
16 changes: 10 additions & 6 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -123,6 +123,7 @@ def __init__(

self.handle_given_plugins(plugins)

self._training_type_plugin_resolved = False
self.accelerator = self.select_accelerator()

# override dist backend when using tpus
@@ -221,10 +222,14 @@ def precision_plugin(self) -> PrecisionPlugin:

@property
def training_type_plugin(self) -> TrainingTypePlugin:
if self._training_type_plugin_resolved:
# avoid calling `resolve_training_type_plugin` multiple times
return self._training_type_plugin
if self._training_type_plugin is None:
self._training_type_plugin = self.select_training_type_plugin()
else:
self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
self._training_type_plugin_resolved = True

return self._training_type_plugin

@@ -315,7 +320,6 @@ def is_using_torchelastic(self) -> bool:
"""
.. deprecated:: v1.3
Will be removed in v1.5.0.

Returns:
``True`` if the current process was launched using the torchelastic command.
"""
@@ -438,20 +442,20 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:

def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
# necessary for when the user has passed in a plugin
if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'):
if hasattr(training_type, 'parallel_devices') and getattr(training_type, 'parallel_devices') is None:
training_type.parallel_devices = self.parallel_devices
if hasattr(training_type, 'num_processes'):
training_type.num_processes = len(self.parallel_devices)

if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None:
training_type.cluster_environment = self.select_cluster_environment()

if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None:
if hasattr(training_type, 'num_nodes'):
# set num_nodes for training_type from trainer setting
training_type.num_nodes = self.num_nodes

# Automatically set sync_batchnorm if None.
# Useful for custom plugins.
if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None:
if hasattr(training_type, 'sync_batchnorm'):
# set sync_batchnorm for training_type from trainer setting
training_type.sync_batchnorm = self.sync_batchnorm

return training_type
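Two behaviours from the connector changes above, reduced to a sketch: the manually passed plugin is resolved only once (guarded by a flag, so its state is not overwritten on repeated access), and `num_nodes`/`sync_batchnorm` are now forwarded unconditionally to any plugin that exposes those attributes. The function and attribute names below are simplified stand-ins, not the real connector API:

```python
def resolve_plugin_once(connector, plugin):
    """Simplified sketch of resolve-once caching plus Trainer-to-plugin forwarding."""
    if getattr(connector, "_resolved", False):
        # already resolved: return the cached plugin instead of resolving it again
        return connector._plugin

    # forward Trainer-level settings onto any plugin that exposes them;
    # assigning num_nodes triggers the plugin's setter, which recomputes world ranks
    if hasattr(plugin, "num_nodes"):
        plugin.num_nodes = connector.num_nodes
    if hasattr(plugin, "sync_batchnorm"):
        plugin.sync_batchnorm = connector.sync_batchnorm

    connector._plugin = plugin
    connector._resolved = True
    return plugin
```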
16 changes: 10 additions & 6 deletions tests/plugins/test_cluster_integration.py
@@ -47,13 +47,14 @@ def environment_combinations():


@pytest.mark.parametrize(
"plugin_cls", [
"plugin_cls",
[
DDPPlugin,
DDPShardedPlugin,
DDP2Plugin,
pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)),
pytest.param(RPCSequentialPlugin, marks=RunIf(fairscale_pipe=True)),
]
],
)
def test_ranks_available_manual_plugin_selection(plugin_cls):
""" Test that the rank information is readily available after Trainer initialization. """
@@ -66,10 +67,12 @@ def test_ranks_available_manual_plugin_selection(plugin_cls):
with mock.patch.dict(os.environ, variables):
plugin = plugin_cls(
parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)],
num_nodes=num_nodes,
cluster_environment=cluster,
)
trainer = Trainer(plugins=[plugin])
trainer = Trainer(
plugins=[plugin],
num_nodes=num_nodes,
)
assert rank_zero_only.rank == expected["global_rank"]
assert trainer.global_rank == expected["global_rank"]
assert trainer.local_rank == expected["local_rank"]
@@ -78,13 +81,14 @@


@pytest.mark.parametrize(
"trainer_kwargs", [
"trainer_kwargs",
[
dict(accelerator="ddp", gpus=[1, 2]),
dict(accelerator="ddp_sharded", gpus=[1, 2]),
dict(accelerator="ddp2", gpus=[1, 2]),
dict(accelerator="ddp_cpu", num_processes=2),
dict(accelerator="ddp_spawn", gpus=[1, 2]),
]
],
)
@mock.patch("torch.cuda.is_available", return_value=True)
@mock.patch("torch.cuda.device_count", return_value=4)
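The updated test now passes `num_nodes` through the `Trainer` rather than the plugin constructor and checks the resulting ranks. A minimal standalone variant of the same idea, asserting directly that the Trainer settings reach a manually passed plugin (illustrative only; it assumes constructing a `Trainer` with a bare `DDPPlugin` is valid in a CPU-only environment for this version):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin


def test_trainer_settings_reach_manual_plugin():
    # the plugin is created without num_nodes/sync_batchnorm; both are set on the Trainer
    trainer = Trainer(num_nodes=2, sync_batchnorm=True, plugins=[DDPPlugin()])
    assert trainer.training_type_plugin.num_nodes == 2
    assert trainer.training_type_plugin.sync_batchnorm is True
```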