
Commit 7c8c7ce

kaushikb11, rohitgr7, and pre-commit-ci[bot] committed
Add strategy argument to Trainer (Lightning-AI#8597)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ee63840 commit 7c8c7ce

File tree

9 files changed: +329 -7 lines changed


CHANGELOG.md

+3

@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
 
 
+- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))
+
+
 ### Changed
 
 - Module imports are now catching `ModuleNotFoundError` instead of `ImportError` ([#9867](https://github.com/PyTorchLightning/pytorch-lightning/pull/9867))

pytorch_lightning/trainer/connectors/accelerator_connector.py

+101 -3

@@ -92,6 +92,7 @@ def __init__(
         tpu_cores,
         ipus,
         accelerator,
+        strategy: Optional[Union[str, TrainingTypePlugin]],
         gpus,
         gpu_ids,
         num_nodes,
@@ -109,14 +110,25 @@ def __init__(
         self._distrib_type = None
         self._accelerator_type = None
 
+<<<<<<< HEAD
+=======
+        self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
+        self.distributed_backend = distributed_backend or accelerator
+
+        self._init_deterministic(deterministic)
+
+>>>>>>> 05b15e63f (Add `strategy` argument to Trainer (#8597))
         self.num_processes = num_processes
         self.devices = devices
         # `gpus` is the input passed to the Trainer, whereas `gpu_ids` is a list of parsed gpu ids.
         self.gpus = gpus
         self.parallel_device_ids = gpu_ids
         self.tpu_cores = tpu_cores
         self.ipus = ipus
+<<<<<<< HEAD
         self.accelerator = accelerator
+=======
+>>>>>>> 05b15e63f (Add `strategy` argument to Trainer (#8597))
         self.num_nodes = num_nodes
         self.sync_batchnorm = sync_batchnorm
         self.benchmark = benchmark
@@ -141,16 +153,23 @@ def __init__(
 
         self.plugins = plugins
 
+        self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator)
+
         self._validate_accelerator_and_devices()
 
         self._warn_if_devices_flag_ignored()
 
         self.select_accelerator_type()
-        self.set_distributed_mode()
+
+        if self.strategy is not None:
+            self._set_training_type_plugin()
+        else:
+            self.set_distributed_mode()
         self.configure_slurm_ddp()
 
         self.handle_given_plugins()
         self.update_device_type_if_ipu_plugin()
+        self.update_device_type_if_training_type_plugin_passed()
 
         self._validate_accelerator_type()
         self._set_devices_if_none()
@@ -275,9 +294,56 @@ def _set_devices_if_none(self) -> None:
         elif self._accelerator_type == DeviceType.CPU:
             self.devices = self.num_processes
 
+    def _handle_accelerator_and_distributed_backend(
+        self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]]
+    ) -> None:
+        if distributed_backend is not None:
+            rank_zero_deprecation(
+                f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
+                f" Use `Trainer(strategy={distributed_backend})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(distributed_backend={distributed_backend})`."
+                    f"HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+        if accelerator is not None and accelerator in list(DistributedType):
+            rank_zero_deprecation(
+                f"Passing {accelerator} `strategy` to the `accelerator` flag in Trainer has been deprecated"
+                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(accelerator={accelerator})`."
+                    f"HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+    def _set_training_type_plugin(self) -> None:
+        if isinstance(self.strategy, str) and self.strategy in TrainingTypePluginsRegistry:
+            self._training_type_plugin = TrainingTypePluginsRegistry.get(self.strategy)
+        if isinstance(self.strategy, str):
+            self.set_distributed_mode(self.strategy)
+        elif isinstance(self.strategy, TrainingTypePlugin):
+            self._training_type_plugin = self.strategy
+
     def handle_given_plugins(self) -> None:
 
-        training_type = None
+        for plug in self.plugins:
+            if self.strategy is not None and self._is_plugin_training_type(plug):
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})`"
+                    f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
+                )
+            if self._is_plugin_training_type(plug):
+                rank_zero_deprecation(
+                    f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
+                    f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
+                )
+
+        training_type = self._training_type_plugin or None
         checkpoint = None
         precision = None
         cluster_environment = None
@@ -340,6 +406,10 @@ def handle_given_plugins(self) -> None:
         self._checkpoint_io = checkpoint
         self._cluster_environment = cluster_environment or self.select_cluster_environment()
 
+    @property
+    def accelerator_types(self) -> List[str]:
+        return ["auto"] + list(DeviceType)
+
     @property
     def precision_plugin(self) -> PrecisionPlugin:
         if self._precision_plugin is None:
@@ -530,9 +600,18 @@ def root_gpu(self) -> Optional[int]:
             else None
         )
 
+    @staticmethod
+    def _is_plugin_training_type(plugin: Union[str, TrainingTypePlugin]) -> bool:
+        if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(DistributedType)):
+            return True
+        return isinstance(plugin, TrainingTypePlugin)
+
     @property
     def is_training_type_in_plugins(self) -> bool:
-        return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
+        return any(
+            (isinstance(plug, str) and plug in TrainingTypePluginsRegistry) or isinstance(plug, TrainingTypePlugin)
+            for plug in self.plugins
+        )
 
     def select_precision_plugin(self) -> PrecisionPlugin:
         # set precision type
@@ -862,6 +941,25 @@ def update_device_type_if_ipu_plugin(self) -> None:
         if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU:
             self._device_type = DeviceType.IPU
 
+    def update_device_type_if_training_type_plugin_passed(self) -> None:
+        if isinstance(self.strategy, TrainingTypePlugin) or any(
+            isinstance(plug, TrainingTypePlugin) for plug in self.plugins
+        ):
+            if self._accelerator_type is not None:
+                if self.use_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.use_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.use_gpu:
+                    self._device_type = DeviceType.GPU
+            else:
+                if self.has_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.has_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.has_gpu:
+                    self._device_type = DeviceType.GPU
+
     def configure_slurm_ddp(self):
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
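In short, the connector now lets an explicit `strategy` take precedence over the legacy flags. The sketch below is a simplified, hypothetical illustration of that precedence only; `resolve_strategy`, `REGISTRY`, and the stub `set_distributed_mode` are invented stand-ins, not PyTorch Lightning APIs.

# Illustrative sketch only; stand-in names, not the actual connector implementation.
from typing import Optional, Union


class TrainingTypePlugin:
    """Stand-in for pytorch_lightning.plugins.TrainingTypePlugin."""


REGISTRY: dict = {}  # stand-in for TrainingTypePluginsRegistry (alias -> plugin class)


def set_distributed_mode(alias: Optional[str] = None) -> None:
    # Placeholder for the legacy distributed-mode resolution driven by the old flags.
    print(f"distributed mode resolved from {alias!r}")


def resolve_strategy(strategy: Optional[Union[str, TrainingTypePlugin]]) -> Optional[TrainingTypePlugin]:
    """Sketch of the precedence added by this commit: an explicit `strategy` wins."""
    if strategy is None:
        set_distributed_mode()  # legacy path: `accelerator` / `distributed_backend` decide
        return None
    if isinstance(strategy, str):
        plugin = REGISTRY[strategy]() if strategy in REGISTRY else None
        set_distributed_mode(strategy)  # string aliases also drive the distributed mode
        return plugin
    return strategy  # a TrainingTypePlugin instance is used as-is


resolve_strategy("ddp")                 # alias: resolved via the distributed-mode path
resolve_strategy(TrainingTypePlugin())  # instance: becomes the training type plugin directly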

pytorch_lightning/trainer/trainer.py

+5

@@ -155,6 +155,7 @@ def __init__(
         flush_logs_every_n_steps: Optional[int] = None,
         log_every_n_steps: int = 50,
         accelerator: Optional[Union[str, Accelerator]] = None,
+        strategy: Optional[Union[str, TrainingTypePlugin]] = None,
         sync_batchnorm: bool = False,
         precision: Union[int, str] = 32,
         enable_model_summary: bool = True,
@@ -351,6 +352,9 @@ def __init__(
                 no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
                 training will start from the beginning of the next epoch.
 
+            strategy: Supports different training strategies with aliases
+                as well custom training type plugins.
+
             sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
 
             terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
@@ -420,6 +424,7 @@ def __init__(
            tpu_cores,
            ipus,
            accelerator,
+           strategy,
            gpus,
            gpu_ids,
            num_nodes,
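The `strategy` docstring above is brief; the tests added further down in this commit exercise both accepted forms. As a usage sketch mirroring `test_strategy_choice_cpu_str` and `test_strategy_choice_cpu_plugin`:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# String alias, as exercised by test_strategy_choice_cpu_str in this commit.
trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2)

# Training type plugin instance, as exercised by test_strategy_choice_cpu_plugin.
trainer = Trainer(strategy=DDPPlugin(), accelerator="cpu", devices=2)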

tests/accelerators/test_accelerator_connector.py

+74 -1

@@ -26,6 +26,7 @@
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import (
+    DataParallelPlugin,
     DDP2Plugin,
     DDPPlugin,
     DDPShardedPlugin,
@@ -42,7 +43,7 @@
     SLURMEnvironment,
     TorchElasticEnvironment,
 )
-from pytorch_lightning.utilities import DistributedType
+from pytorch_lightning.utilities import DeviceType, DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -631,6 +632,78 @@ def test_accelerator_ddp_for_cpu(tmpdir):
     assert isinstance(trainer.training_type_plugin, DDPPlugin)
 
 
+def test_exception_when_strategy_used_with_distributed_backend():
+    with pytest.raises(MisconfigurationException, match="but have also passed"):
+        Trainer(distributed_backend="ddp_cpu", strategy="ddp_spawn")
+
+
+def test_exception_when_strategy_used_with_accelerator():
+    with pytest.raises(MisconfigurationException, match="but have also passed"):
+        Trainer(accelerator="ddp", strategy="ddp_spawn")
+
+
+def test_exception_when_strategy_used_with_plugins():
+    with pytest.raises(MisconfigurationException, match="only specify one training type plugin, but you have passed"):
+        Trainer(plugins="ddp_find_unused_parameters_false", strategy="ddp_spawn")
+
+
+@pytest.mark.parametrize(
+    ["strategy", "plugin"],
+    [
+        ("ddp_spawn", DDPSpawnPlugin),
+        ("ddp_spawn_find_unused_parameters_false", DDPSpawnPlugin),
+        ("ddp", DDPPlugin),
+        ("ddp_find_unused_parameters_false", DDPPlugin),
+    ],
+)
+def test_strategy_choice_cpu_str(tmpdir, strategy, plugin):
+    trainer = Trainer(strategy=strategy, accelerator="cpu", devices=2)
+    assert isinstance(trainer.training_type_plugin, plugin)
+
+
+@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
+def test_strategy_choice_cpu_plugin(tmpdir, plugin):
+    trainer = Trainer(strategy=plugin(), accelerator="cpu", devices=2)
+    assert isinstance(trainer.training_type_plugin, plugin)
+
+
+@RunIf(min_gpus=2)
+@pytest.mark.parametrize(
+    ["strategy", "plugin"],
+    [
+        ("ddp_spawn", DDPSpawnPlugin),
+        ("ddp_spawn_find_unused_parameters_false", DDPSpawnPlugin),
+        ("ddp", DDPPlugin),
+        ("ddp_find_unused_parameters_false", DDPPlugin),
+        ("ddp2", DDP2Plugin),
+        ("dp", DataParallelPlugin),
+        ("ddp_sharded", DDPShardedPlugin),
+        ("ddp_sharded_spawn", DDPSpawnShardedPlugin),
+        pytest.param("deepspeed", DeepSpeedPlugin, marks=RunIf(deepspeed=True)),
+    ],
+)
+def test_strategy_choice_gpu_str(tmpdir, strategy, plugin):
+    trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)
+    assert isinstance(trainer.training_type_plugin, plugin)
+
+
+@RunIf(min_gpus=2)
+@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
+def test_strategy_choice_gpu_plugin(tmpdir, plugin):
+    trainer = Trainer(strategy=plugin(), accelerator="gpu", devices=2)
+    assert isinstance(trainer.training_type_plugin, plugin)
+
+
+@RunIf(min_gpus=2)
+@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin])
+def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin):
+
+    trainer = Trainer(strategy=plugin(), gpus=2)
+    assert isinstance(trainer.training_type_plugin, plugin)
+    assert trainer._device_type == DeviceType.GPU
+    assert isinstance(trainer.accelerator, GPUAccelerator)
+
+
 @pytest.mark.parametrize("precision", [1, 12, "invalid"])
 def test_validate_precision_type(tmpdir, precision):

tests/accelerators/test_ipu.py

+17 -2

@@ -24,7 +24,7 @@
 from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.utilities import _IPU_AVAILABLE
+from pytorch_lightning.utilities import _IPU_AVAILABLE, DeviceType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
@@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir):
 @RunIf(ipu=True)
 def test_no_warning_plugin(tmpdir):
     with pytest.warns(None) as record:
-        Trainer(default_root_dir=tmpdir, plugins=IPUPlugin(training_opts=poptorch.Options()))
+        Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options()))
     assert len(record) == 0
 
 
@@ -528,3 +528,18 @@ def test_set_devices_if_none_ipu():
 
     trainer = Trainer(accelerator="ipu", ipus=8)
     assert trainer.devices == 8
+
+
+@RunIf(ipu=True)
+def test_strategy_choice_ipu_plugin(tmpdir):
+    trainer = Trainer(strategy=IPUPlugin(), accelerator="ipu", devices=8)
+    assert isinstance(trainer.training_type_plugin, IPUPlugin)
+
+
+@RunIf(ipu=True)
+def test_device_type_when_training_plugin_ipu_passed(tmpdir):
+
+    trainer = Trainer(strategy=IPUPlugin(), ipus=8)
+    assert isinstance(trainer.training_type_plugin, IPUPlugin)
+    assert trainer._device_type == DeviceType.IPU
+    assert isinstance(trainer.accelerator, IPUAccelerator)

tests/accelerators/test_tpu_backend.py

+13

@@ -227,6 +227,19 @@ def test_ddp_cpu_not_supported_on_tpus():
         Trainer(accelerator="ddp_cpu")
 
 
+@RunIf(tpu=True)
+@pytest.mark.parametrize("strategy", ["tpu_spawn", "tpu_spawn_debug"])
+def test_strategy_choice_tpu_str(tmpdir, strategy):
+    trainer = Trainer(strategy=strategy, accelerator="tpu", devices=8)
+    assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
+
+
+@RunIf(tpu=True)
+def test_strategy_choice_tpu_plugin(tmpdir):
+    trainer = Trainer(strategy=TPUSpawnPlugin(), accelerator="tpu", devices=8)
+    assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
+
+
 @RunIf(tpu=True)
 def test_auto_parameters_tying_tpus(tmpdir):
tests/deprecated_api/test_remove_1-7.py

+10

@@ -343,6 +343,16 @@ def test_v1_7_0_deprecate_parameter_validation():
         from pytorch_lightning.core.decorators import parameter_validation  # noqa: F401
 
 
+def test_v1_7_0_passing_strategy_to_accelerator_trainer_flag():
+    with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
+        Trainer(accelerator="ddp_spawn")
+
+
+def test_v1_7_0_passing_strategy_to_plugins_flag():
+    with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
+        Trainer(plugins="ddp_spawn")
+
+
 def test_v1_7_0_weights_summary_trainer(tmpdir):
     with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=full\)` is deprecated in v1.5"):
         t = Trainer(weights_summary="full")
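The two deprecation tests above pin down the migration path. The before/after below is a sketch inferred from the deprecation messages added in this commit, not from separate documentation:

from pytorch_lightning import Trainer

# Deprecated in v1.5, removal planned for v1.7 (both now emit a deprecation warning):
Trainer(accelerator="ddp_spawn")
Trainer(plugins="ddp_spawn")

# Replacement introduced by this commit:
Trainer(strategy="ddp_spawn")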
