@@ -4,13 +4,14 @@
 from collections import OrderedDict
 from contextlib import ExitStack
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Mapping, Optional, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Literal, Mapping, Optional, TypeVar, Union, cast

 import pytorch_lightning as pl
 import torch
 import torch.distributed
 from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning_fabric.utilities.optimizer import _optimizers_to_device
+from megatron.core.distributed import DistributedDataParallelConfig
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.callbacks.progress import TQDMProgressBar
 from pytorch_lightning.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop
@@ -38,6 +39,9 @@
 ConfigT = TypeVar("ConfigT")


+DDPLiteral = Literal["megatron", "pytorch"]
+
+
 class MegatronStrategy(DDPStrategy, io.IOMixin):
     """Megatron plugin for Pytorch Lightning.
@@ -58,11 +62,11 @@ def __init__(
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment=None,  # TODO: Add type-hint
         checkpoint_io=None,  # TODO: Add type-hint
-        no_ddp_communication_hook: bool = True,
         find_unused_parameters: bool = False,
         enable_nemo_ckpt_io: bool = True,
         ckpt_type: TrainerCkptProtocol = TrainerCheckpoint,
         ckpt_include_optimizer: bool = False,
+        ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
         lazy_init: bool = False,
         **kwargs,
     ) -> None:
@@ -73,7 +77,7 @@ def __init__(
             find_unused_parameters=find_unused_parameters,
             **kwargs,
         )
-        self.no_ddp_communication_hook = no_ddp_communication_hook
+
         self.megatron_callbacks = CallbackConnector()
         self.data_sampler: Optional['DataSampler'] = data_sampler
         self.tensor_model_parallel_size = tensor_model_parallel_size
@@ -85,6 +89,16 @@ def __init__(
         self.lazy_init = lazy_init
         self.ckpt_include_optimizer = ckpt_include_optimizer

+        if ddp == "megatron":
+            self.ddp_config = DistributedDataParallelConfig()
+        elif isinstance(ddp, DistributedDataParallelConfig):
+            self.ddp_config = ddp
+        elif ddp == "pytorch":
+            self.ddp_config = None
+            self.no_ddp_communication_hook = False
+        else:
+            raise ValueError(f"Invalid DDP type: {ddp}")
+
         # used in NVIDIA NGC PyTorch containers
         _strategy_lib.enable_nvidia_optimizations()
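Note: the branching above means the new `ddp` argument accepts either a literal or a full Megatron DDP config object, and anything else raises a ValueError. A minimal usage sketch, not part of this diff (argument values are illustrative only):

    strategy = MegatronStrategy(ddp="megatron")                       # default: Megatron DDP with a default DistributedDataParallelConfig
    strategy = MegatronStrategy(ddp=DistributedDataParallelConfig())  # pass an explicit Megatron DDP config
    strategy = MegatronStrategy(ddp="pytorch")                        # keep PyTorch DistributedDataParallel wrapping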
@@ -153,6 +167,9 @@ def setup(self, trainer: pl.Trainer) -> None:

         # set up optimizers after the wrapped module has been moved to the device
         self.setup_optimizers(trainer)
+
+        # TODO: Throw an exception if we have an mcore optimizer and no ddp_config
+
         if hasattr(self.precision_plugin, "convert_optimizer"):
             _optimizers = [*self.optimizers]
             _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
@@ -204,6 +221,7 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
             precision_plugin=self.precision_plugin,
             vp_size=self.virtual_pipeline_model_parallel_size,
             cpu=isinstance(trainer.accelerator, CPUAccelerator),
+            ddp_config=self.ddp_config,
         )
         self.model = self.megatron_parallel
         self.model.trainer = trainer
@@ -212,6 +230,10 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
             self.model = self.precision_plugin.convert_module(self.model)
         self.model.callbacks.add(getattr(trainer, "callbacks"))

+        if hasattr(self, "optimizers") and self.optimizers:
+            for optimizer in self.optimizers:
+                self.model.callbacks.add(optimizer)
+
         if self.data_sampler:
             self.model.callbacks.add(self.data_sampler)

@@ -223,10 +245,11 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
     def configure_ddp(self) -> None:
         logging.debug(f"{self.__class__.__name__}: configuring MegatronParallel")
         self.model = self._setup_model(self.model)
-        self._register_ddp_hooks()
+        if self.ddp_config is None:
+            self._register_ddp_hooks()

     @override
-    def _setup_model(self, model: nn.Module) -> DistributedDataParallel:
+    def _setup_model(self, model: nn.Module) -> nn.Module:
         """Only called when we need to wrap the model for pytorch's ddp."""
         from megatron.core import parallel_state
@@ -236,16 +259,19 @@ def _setup_model(self, model: nn.Module) -> DistributedDataParallel:
         if app_state.model_parallel_size is not None:
             self._ddp_kwargs["process_group"] = parallel_state.get_data_parallel_group()

-        dist_data_parallel: DistributedDataParallel = super()._setup_model(model)
-        if self.no_ddp_communication_hook:
-            # When using custom gradient accumulation and allreduce, disable
-            # DDP communication hook that works on the gradient bucket.
-            # Instead, use the custom gradient function and communication hook,
-            # which is defined in the master optimizer wrapper.
-            dist_data_parallel.require_backward_grad_sync = False
-            dist_data_parallel.register_comm_hook(None, noop_hook)
+        # Only wrap the model if we are not using Megatron's DDP
+        if not self.ddp_config:
+            dist_data_parallel: DistributedDataParallel = super()._setup_model(model)
+            if self.no_ddp_communication_hook:
+                # When using custom gradient accumulation and allreduce, disable
+                # DDP communication hook that works on the gradient bucket.
+                # Instead, use the custom gradient function and communication hook,
+                # which is defined in the master optimizer wrapper.
+                dist_data_parallel.require_backward_grad_sync = False
+                dist_data_parallel.register_comm_hook(None, noop_hook)
+            model = dist_data_parallel

-        return dist_data_parallel
+        return model

     def _setup_parallel_ranks(self) -> None:
         self.set_world_ranks()
@@ -260,7 +286,7 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "training")

         with self.precision_plugin.train_step_context():  # TODO: Do we need this?
-            return self.model(dataloader_iter, *args, **kwargs)
+            return self.model(dataloader_iter, forward_only=False, *args, **kwargs)

     @override
     def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -269,7 +295,7 @@ def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "validation")

         with self.precision_plugin.val_step_context():  # TODO: Do we need this?
-            return self.model(dataloader_iter, *args, **kwargs)
+            return self.model(dataloader_iter, forward_only=True, *args, **kwargs)

     @override
     def test_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -278,7 +304,7 @@ def test_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "test")

         with self.precision_plugin.test_step_context():  # TODO: Do we need this?
-            return self.model(dataloader_iter, *args, **kwargs)
+            return self.model(dataloader_iter, forward_only=True, *args, **kwargs)

     @override
     def predict_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -287,7 +313,7 @@ def predict_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "predict")

         with self.precision_plugin.predict_step_context():  # TODO: Do we need this?
-            return self.model(dataloader_iter, *args, **kwargs)
+            return self.model(dataloader_iter, forward_only=True, *args, **kwargs)

     @override
     def teardown(self) -> None:
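For context, a hedged sketch of how this strategy might be wired into a Lightning trainer after this change; the Trainer arguments below are illustrative assumptions, not taken from this diff:

    import pytorch_lightning as pl
    from megatron.core.distributed import DistributedDataParallelConfig

    trainer = pl.Trainer(
        devices=2,          # illustrative device count
        accelerator="gpu",
        strategy=MegatronStrategy(
            ddp=DistributedDataParallelConfig(),  # use Megatron's DDP with default settings
            ckpt_include_optimizer=True,          # optimizer state in checkpoints (parameter from the signature above)
        ),
    )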