From 4ebac4878e3bfe971f3d7c518ea2e8996d2c80bb Mon Sep 17 00:00:00 2001
From: ruit
Date: Sun, 10 Aug 2025 07:55:54 -0700
Subject: [PATCH] Add TP to embed_tokens and lm_head for Gemma models

Signed-off-by: ruit
---
 nemo_rl/models/dtensor/parallelize.py | 23 +++++++----------------
 1 file changed, 7 insertions(+), 16 deletions(-)

diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py
index f0cdadd1a9..25ecaf8051 100644
--- a/nemo_rl/models/dtensor/parallelize.py
+++ b/nemo_rl/models/dtensor/parallelize.py
@@ -31,8 +31,6 @@ from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     ParallelStyle,
-    PrepareModuleInput,
-    PrepareModuleOutput,
     RowwiseParallel,
     SequenceParallel,
     parallelize_module,
@@ -93,18 +91,14 @@ def _parallelize_gemma3(
     model: Union[Gemma3ForCausalLM, Gemma3ForConditionalGeneration],
     sequence_parallel: bool = False,
 ) -> dict[str, ParallelStyle]:
-    """Parallelizes a Gemma3ForCausalLM model across data parallel dimensions.
-
-    Tensor parallelism is not supported for Gemma3 models because of tied word embeddings.
-    """
+    """Parallelizes a Gemma3 model across data and tensor parallel dimensions."""
     if isinstance(model, Gemma3ForConditionalGeneration):
         model_prefix = "model.language_model"
     else:
         model_prefix = "model"

-    # For gemma3 models, we don't include the model.embed_tokens and lm_head in the
-    # parallelization plans because they have tied weights.
     base_model_tp_plan: dict[str, ParallelStyle] = {
+        f"{model_prefix}.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
         f"{model_prefix}.layers.*.self_attn.q_proj": ColwiseParallel(),
         f"{model_prefix}.layers.*.self_attn.k_proj": ColwiseParallel(),
         f"{model_prefix}.layers.*.self_attn.v_proj": ColwiseParallel(),
@@ -112,13 +106,12 @@ def _parallelize_gemma3(
         f"{model_prefix}.layers.*.mlp.up_proj": ColwiseParallel(),
         f"{model_prefix}.layers.*.mlp.gate_proj": ColwiseParallel(),
         f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel(),
+        "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
     }

     base_model_sp_plan = {
-        f"{model_prefix}.embed_tokens": PrepareModuleOutput(
-            output_layouts=Replicate(),
-            desired_output_layouts=Shard(1),
-            use_local_output=False,
+        f"{model_prefix}.embed_tokens": RowwiseParallel(
+            input_layouts=Replicate(), output_layouts=Shard(1)
         ),
         f"{model_prefix}.rotary_emb": RotaryEmbedParallel(use_local_output=True),
         f"{model_prefix}.rotary_emb_local": RotaryEmbedParallel(use_local_output=True),
@@ -133,10 +126,8 @@ def _parallelize_gemma3(
         ),
         f"{model_prefix}.layers.*.post_feedforward_layernorm": SequenceParallel(),
         f"{model_prefix}.norm": SequenceParallel(),
-        "lm_head": PrepareModuleInput(
-            input_layouts=(Shard(1),),
-            desired_input_layouts=(Replicate(),),
-            use_local_output=True,
+        "lm_head": ColwiseParallel(
+            input_layouts=Shard(1), output_layouts=Shard(-1), use_local_output=False
         ),
     }
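
Note (editorial sketch, not part of the patch): the plan entries above use the
torch.distributed.tensor.parallel API. The snippet below is a minimal sketch of
how such a plan is applied with parallelize_module. The TinyLM module, its
sizes, and the one-dimensional "tp" mesh are illustrative assumptions, and a
recent PyTorch is assumed (where Replicate/Shard live under
torch.distributed.tensor); only the two plan entries mirror the patch.

import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class TinyLM(nn.Module):
    """Hypothetical stand-in for the Gemma layout: embedding plus lm_head."""

    def __init__(self, vocab: int = 256, hidden: int = 64):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.embed_tokens(ids))


# Launched via torchrun; WORLD_SIZE and LOCAL_RANK are set by the launcher.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
world_size = int(os.environ["WORLD_SIZE"])
tp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))

plan = {
    # RowwiseParallel on an embedding shards the weight along the vocab dim;
    # token ids stay replicated, and the per-rank partial embeddings are
    # reduced back into a full hidden state on every rank.
    "embed_tokens": RowwiseParallel(input_layouts=Replicate()),
    # ColwiseParallel on lm_head shards the output (vocab) dim. Keeping the
    # output as a DTensor sharded on the last dim avoids an eager all-gather
    # of the logits.
    "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
}
model = parallelize_module(TinyLM().cuda(), tp_mesh, plan)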
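
Continuing the sketch above: because the lm_head plan sets
use_local_output=False with output_layouts=Shard(-1), the forward pass returns
a DTensor whose vocab dimension is sharded across TP ranks. Callers can keep
it sharded (e.g. for a vocab-parallel loss) or materialize the full logits:

ids = torch.randint(0, 256, (2, 16), device="cuda")
logits = model(ids)                 # DTensor sharded on the last (vocab) dim
full_logits = logits.full_tensor()  # all-gather into a regular local tensor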