Skip to content
Merged
4 changes: 2 additions & 2 deletions nemo_rl/models/dtensor/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _parallelize_gemma3(
Tensor parallelism is not supported for Gemma3 models because of tied word embeddings.
"""
if isinstance(model, Gemma3ForConditionalGeneration):
model_prefix = "language_model"
model_prefix = "model.language_model"
else:
model_prefix = "model"

Expand Down Expand Up @@ -127,7 +127,7 @@ def _parallelize_gemma3(
),
f"{model_prefix}.layers.*.post_feedforward_layernorm": SequenceParallel(),
f"{model_prefix}.norm": SequenceParallel(),
f"{model_prefix}.lm_head": PrepareModuleInput(
"lm_head": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
use_local_output=True,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ automodel = [
"flash-attn==2.7.4.post1",
]
vllm = [
"vllm==0.9.0",
"vllm==0.9.2",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"flash-attn==2.7.4.post1",
]
Expand Down
68 changes: 68 additions & 0 deletions tests/unit/models/dtensor/test_parallelize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import product
from unittest.mock import MagicMock

import pytest
from torch.distributed.tensor.parallel import ParallelStyle, parallelize_module
from transformers import AutoModelForCausalLM

from nemo_rl.models.dtensor.parallelize import (
_parallelize_gemma3,
_parallelize_llama,
_parallelize_qwen,
)


@pytest.mark.parametrize(
"model_name, parallelize_func, sequence_parallel",
[
(model_name, parallelize_func, sp)
for (model_name, parallelize_func), sp in product(
[
("google/gemma-3-1b-it", _parallelize_gemma3),
("google/gemma-3-4b-it", _parallelize_gemma3),
# ("Qwen/Qwen2.5-1.5B", _parallelize_qwen), # TODO: qwen2 doesn't have q_norm and k_norm, which will cause this test to fail
("Qwen/Qwen3-0.6B", _parallelize_qwen),
("meta-llama/Llama-3.2-1B-Instruct", _parallelize_llama),
],
[True, False],
)
],
)
def test_parallelize_plan_keys(model_name, parallelize_func, sequence_parallel):
"""Tests that the keys in the parallelization plans are valid by mocking parallel styles."""
model = AutoModelForCausalLM.from_pretrained(model_name)
parallel_plan = parallelize_func(model, sequence_parallel=sequence_parallel)

applied_keys = set()

class MockParallelStyle(ParallelStyle):
def __init__(self, key, collector):
self.key = key
self.collector = collector

def _apply(self, module, device_mesh):
self.collector.add(self.key)

mock_plan = {key: MockParallelStyle(key, applied_keys) for key in parallel_plan}
dummy_device_mesh = MagicMock()
dummy_device_mesh.ndim = 1

parallelize_module(model, dummy_device_mesh, mock_plan)

assert set(parallel_plan.keys()) == applied_keys, (
f"Missing keys: {set(parallel_plan.keys()) - applied_keys}"
)
Loading
Loading