Skip to content

Commit

Permalink
Add strategy="auto" support on the 1.9.x branch (#16916)
Browse files Browse the repository at this point in the history
* Fix auto support on the 1.9.x branch

* CHANGELOG

* CHANGELOG

* Fix CHANGELOG
  • Loading branch information
carmocca authored Mar 1, 2023
1 parent 8e55ff7 commit 3bee819
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.9.4] - 2023-02-28
## [1.9.4] - 2023-03-01

### Removed

Expand Down
6 changes: 5 additions & 1 deletion src/lightning_fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.9.4] - 2023-02-28
## [1.9.4] - 2023-03-01

### Added

- Added `Fabric(strategy="auto")` support ([#16916](https://github.com/Lightning-AI/lightning/pull/16916))

### Fixed

Expand Down
8 changes: 6 additions & 2 deletions src/lightning_fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()

# 4. Instantiate Strategy - Part 1
if self._strategy_flag is None:
if self._strategy_flag in (None, "auto"):
self._strategy_flag = self._choose_strategy()
# In specific cases, ignore user selection and fall back to a different strategy
self._check_strategy_and_fallback()
Expand Down Expand Up @@ -184,7 +184,11 @@ def _check_config_and_set_final_flags(
if strategy is not None:
self._strategy_flag = strategy

if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
if (
strategy not in (None, "auto")
and strategy not in self._registered_strategies
and not isinstance(strategy, Strategy)
):
raise ValueError(
f"You selected an invalid strategy name: `strategy={strategy!r}`."
" It must be either a string or an instance of `lightning.fabric.strategies.Strategy`."
Expand Down
6 changes: 5 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.9.4] - 2023-02-28
## [1.9.4] - 2023-03-01

### Added

- Added `Trainer(strategy="auto")` support. It will choose DDP over DDP-spawn, contrary to `strategy=None` (default) ([#16916](https://github.com/Lightning-AI/lightning/pull/16916))

### Fixed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(
self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()

# 4. Instantiate Strategy - Part 1
if self._strategy_flag is None:
if self._strategy_flag in (None, "auto"):
self._strategy_flag = self._choose_strategy()
# In specific cases, ignore user selection and fall back to a different strategy
self._check_strategy_and_fallback()
Expand Down Expand Up @@ -273,7 +273,11 @@ def _check_config_and_set_final_flags(
" you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
)

if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
if (
strategy not in (None, "auto")
and strategy not in self._registered_strategies
and not isinstance(strategy, Strategy)
):
raise ValueError(
f"You selected an invalid strategy name: `strategy={strategy!r}`."
" It must be either a string or an instance of `pytorch_lightning.strategies.Strategy`."
Expand Down Expand Up @@ -639,6 +643,9 @@ def _choose_strategy(self) -> Union[Strategy, str]:
if len(self._parallel_devices) > 1:
if _IS_INTERACTIVE:
return "ddp_fork"
if self._strategy_flag == "auto":
# None chooses "ddp_spawn" for backwards compatibility, auto chooses "ddp" for future compatibility
return "ddp"
return "ddp_spawn"

return DDPStrategy.strategy_name
Expand Down
16 changes: 16 additions & 0 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,3 +893,19 @@ def get_defaults(cls):
# defaults should match on the intersection of argument names
for name, connector_default in connector_defaults.items():
assert connector_default == fabric_defaults[name]


@mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=2)
@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_connector_auto_selection(*_):
    """Check what the Fabric connector resolves when every flag is auto-detected.

    ``num_cuda_devices`` is patched to report 2 devices and MPS is patched to
    be unavailable. The same expectations are checked for ``strategy=None``
    and ``strategy="auto"``: a CUDA accelerator, a DDP strategy launched via
    the subprocess script launcher, and device indices ``[0, 1]``.
    """
    # Both spellings of "let the connector decide" must resolve identically.
    for strategy_flag in (None, "auto"):
        connector = _Connector(accelerator="auto", strategy=strategy_flag, devices="auto")
        assert isinstance(connector.accelerator, CUDAAccelerator)
        assert isinstance(connector.strategy, DDPStrategy)
        assert isinstance(connector.strategy.launcher, _SubprocessScriptLauncher)
        assert connector._devices_flag == [0, 1]
Original file line number Diff line number Diff line change
Expand Up @@ -892,3 +892,15 @@ def get_defaults(cls):
for name, connector_default in connector_defaults.items():
name = lut.get(name, name)
assert connector_default == trainer_defaults[name]


def test_connector_auto_selection(cuda_count_2, mps_count_0):
    """Check which strategy the Trainer picks when accelerator and devices are "auto".

    With ``strategy=None`` the Trainer is expected to use ``DDPSpawnStrategy``,
    while ``strategy="auto"`` is expected to use ``DDPStrategy``. In both cases
    the accelerator must be CUDA and two devices must be reported.

    NOTE(review): ``cuda_count_2`` / ``mps_count_0`` are presumably fixtures that
    simulate two CUDA devices and no MPS device -- confirm against conftest.
    """
    # Pairs of (strategy flag passed to the Trainer, strategy class it must resolve to).
    expected_strategy_types = (
        (None, DDPSpawnStrategy),
        ("auto", DDPStrategy),
    )
    for strategy_flag, strategy_cls in expected_strategy_types:
        trainer = Trainer(accelerator="auto", strategy=strategy_flag, devices="auto")
        assert isinstance(trainer.accelerator, CUDAAccelerator)
        assert isinstance(trainer.strategy, strategy_cls)
        assert trainer.num_devices == 2

0 comments on commit 3bee819

Please sign in to comment.