@@ -92,6 +92,7 @@ def __init__(
         tpu_cores,
         ipus,
         accelerator,
+        strategy: Optional[Union[str, TrainingTypePlugin]],
         gpus,
         gpu_ids,
         num_nodes,
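
Note: with this hunk the constructor accepts the new `strategy` flag alongside the legacy ones. A minimal usage sketch (the `DDPPlugin` import path and kwargs are assumptions based on the plugins package of this Lightning version):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin

    # `strategy` takes either a known strategy name ...
    trainer = Trainer(strategy="ddp", gpus=2)
    # ... or a TrainingTypePlugin instance
    trainer = Trainer(strategy=DDPPlugin(find_unused_parameters=False), gpus=2)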
@@ -109,14 +110,25 @@ def __init__(
         self._distrib_type = None
         self._accelerator_type = None

+<<<<<<< HEAD
+=======
+        self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
+        self.distributed_backend = distributed_backend or accelerator
+
+        self._init_deterministic(deterministic)
+
+>>>>>>> 05b15e63f (Add `strategy` argument to Trainer (#8597))
         self.num_processes = num_processes
         self.devices = devices
         # `gpus` is the input passed to the Trainer, whereas `gpu_ids` is a list of parsed gpu ids.
         self.gpus = gpus
         self.parallel_device_ids = gpu_ids
         self.tpu_cores = tpu_cores
         self.ipus = ipus
+<<<<<<< HEAD
         self.accelerator = accelerator
+=======
+>>>>>>> 05b15e63f (Add `strategy` argument to Trainer (#8597))
         self.num_nodes = num_nodes
         self.sync_batchnorm = sync_batchnorm
         self.benchmark = benchmark
@@ -141,16 +153,23 @@ def __init__(

         self.plugins = plugins

+        self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator)
+
         self._validate_accelerator_and_devices()

         self._warn_if_devices_flag_ignored()

         self.select_accelerator_type()
-        self.set_distributed_mode()
+
+        if self.strategy is not None:
+            self._set_training_type_plugin()
+        else:
+            self.set_distributed_mode()
         self.configure_slurm_ddp()

         self.handle_given_plugins()
         self.update_device_type_if_ipu_plugin()
+        self.update_device_type_if_training_type_plugin_passed()

         self._validate_accelerator_type()
         self._set_devices_if_none()
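
Note: the constructor now branches on the new flag: a non-None `strategy` is resolved through `_set_training_type_plugin()`, while the previous `set_distributed_mode()` path is kept for every other call. A sketch of the two entry points (the flag combinations are illustrative assumptions):

    Trainer(strategy="ddp_spawn")  # -> _set_training_type_plugin()
    Trainer(gpus=2)                # no strategy -> set_distributed_mode(), as before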
@@ -275,9 +294,56 @@ def _set_devices_if_none(self) -> None:
         elif self._accelerator_type == DeviceType.CPU:
             self.devices = self.num_processes

+    def _handle_accelerator_and_distributed_backend(
+        self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]]
+    ) -> None:
+        if distributed_backend is not None:
+            rank_zero_deprecation(
+                f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
+                f" Use `Trainer(strategy={distributed_backend})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(distributed_backend={distributed_backend})`."
+                    f" HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+        if accelerator is not None and accelerator in list(DistributedType):
+            rank_zero_deprecation(
+                f"Passing {accelerator} `strategy` to the `accelerator` flag in Trainer has been deprecated"
+                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(accelerator={accelerator})`."
+                    f" HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+    def _set_training_type_plugin(self) -> None:
+        if isinstance(self.strategy, str) and self.strategy in TrainingTypePluginsRegistry:
+            self._training_type_plugin = TrainingTypePluginsRegistry.get(self.strategy)
+        if isinstance(self.strategy, str):
+            self.set_distributed_mode(self.strategy)
+        elif isinstance(self.strategy, TrainingTypePlugin):
+            self._training_type_plugin = self.strategy
+
     def handle_given_plugins(self) -> None:

-        training_type = None
+        for plug in self.plugins:
+            if self.strategy is not None and self._is_plugin_training_type(plug):
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})`"
+                    f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
+                )
+            if self._is_plugin_training_type(plug):
+                rank_zero_deprecation(
+                    f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
+                    f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
+                )
+
+        training_type = self._training_type_plugin or None
         checkpoint = None
         precision = None
         cluster_environment = None
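
Note: taken together, the new checks give the following call-site behavior (a sketch; the warning and error texts are the ones constructed in the hunk above):

    Trainer(distributed_backend="ddp")              # deprecation warning: use Trainer(strategy="ddp")
    Trainer(accelerator="ddp")                      # deprecation warning: use Trainer(strategy="ddp")
    Trainer(strategy="ddp", accelerator="ddp")      # raises MisconfigurationException
    Trainer(strategy="ddp", plugins=[DDPPlugin()])  # raises: only one training type plugin allowed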
@@ -340,6 +406,10 @@ def handle_given_plugins(self) -> None:
         self._checkpoint_io = checkpoint
         self._cluster_environment = cluster_environment or self.select_cluster_environment()

+    @property
+    def accelerator_types(self) -> List[str]:
+        return ["auto"] + list(DeviceType)
+
     @property
     def precision_plugin(self) -> PrecisionPlugin:
         if self._precision_plugin is None:
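
Note: `DeviceType` is a str-valued enum in this codebase, so the new property yields the accelerator flags a user may pass. Assuming the members cpu/gpu/tpu/ipu, it is roughly equivalent to:

    ["auto", "cpu", "gpu", "tpu", "ipu"]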
@@ -530,9 +600,18 @@ def root_gpu(self) -> Optional[int]:
             else None
         )

+    @staticmethod
+    def _is_plugin_training_type(plugin: Union[str, TrainingTypePlugin]) -> bool:
+        if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(DistributedType)):
+            return True
+        return isinstance(plugin, TrainingTypePlugin)
+
     @property
     def is_training_type_in_plugins(self) -> bool:
-        return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
+        return any(
+            (isinstance(plug, str) and plug in TrainingTypePluginsRegistry) or isinstance(plug, TrainingTypePlugin)
+            for plug in self.plugins
+        )
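
Note: the new helper and the widened property accept both spellings of a training type plugin. A sketch of the expected results (`connector` stands for an AcceleratorConnector instance; "deepspeed" assumes that name is present in TrainingTypePluginsRegistry):

    connector._is_plugin_training_type("ddp")        # True: a DistributedType value
    connector._is_plugin_training_type("deepspeed")  # True: a registry entry
    connector._is_plugin_training_type(DDPPlugin())  # True: a TrainingTypePlugin instance
    connector._is_plugin_training_type("native")     # False: some unrelated string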
@@ -862,6 +941,25 @@ def update_device_type_if_ipu_plugin(self) -> None:
         if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU:
             self._device_type = DeviceType.IPU

+    def update_device_type_if_training_type_plugin_passed(self) -> None:
+        if isinstance(self.strategy, TrainingTypePlugin) or any(
+            isinstance(plug, TrainingTypePlugin) for plug in self.plugins
+        ):
+            if self._accelerator_type is not None:
+                if self.use_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.use_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.use_gpu:
+                    self._device_type = DeviceType.GPU
+            else:
+                if self.has_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.has_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.has_gpu:
+                    self._device_type = DeviceType.GPU
+
     def configure_slurm_ddp(self):
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
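
Note: this last hook runs at the end of __init__ (see the earlier hunk) so that a TrainingTypePlugin passed via `strategy` or `plugins` also pins `_device_type`. Sketch:

    trainer = Trainer(strategy=DDPPlugin(), gpus=2)  # _device_type resolves to DeviceType.GPU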