8 changes: 0 additions & 8 deletions docs/model-quirks.md
@@ -4,14 +4,6 @@ This document outlines special cases and model-specific behaviors that require c

## Gemma-3

### Tied Weights

Weight tying between the embedding layer (`model.embed_tokens`) and the output layer (`lm_head`) is currently not respected when using the DTensor policy with TP > 1 (see [this issue](https://github.com/NVIDIA-NeMo/RL/issues/227)). To avoid errors when training these models, we only allow training models with tied weights using the DTensor policy with TP=1. For Llama-3 and Qwen2.5 models, weight tying is only enabled for the smaller models (< 2B), which can typically be trained without tensor parallelism. For Gemma-3, weight tying is enabled for all model sizes, including the larger models that require tensor parallelism. To support training these models, we special-case Gemma-3 and allow training with the DTensor policy at TP > 1.

**Special Handling:**
- We skip the tied weights check for all Gemma-3 models when using the DTensor policy, allowing training with TP > 1.
- We exclude `model.embed_tokens` and `lm_head` from the DTensor tensor parallel plan to maintain weight tying correctly.
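
As a quick illustration of the weight tying described above, the sketch below checks whether a checkpoint ties `model.embed_tokens` and `lm_head`. It reuses the `find_tied_parameters` helper imported in `nemo_rl/models/policy/dtensor_policy_worker.py`; this is only a sketch, not part of the NeMo RL API.

```python
# Sketch: inspect whether a model ties its input/output embeddings.
# Assumes the same import path used in dtensor_policy_worker.py.
from transformers import AutoModelForCausalLM
from transformers.integrations.accelerate import find_tied_parameters

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B")
tied_groups = find_tied_parameters(model)  # e.g. [["lm_head.weight", "model.embed_tokens.weight"]]
print(model.config.tie_word_embeddings, tied_groups)
```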

### vLLM Initialization

Gemma-3 models have a specific issue with vLLM dummy weight initialization due to a vLLM bug where [a `normalizer` buffer is created](https://github.com/vllm-project/vllm/blob/964472b9667508b1d4a7ed92068ff81740ae0036/vllm/model_executor/models/gemma3.py#L372) that is not present in the Hugging Face model. This causes the `normalizer` buffer to be set to dummy weights at initialization and then never updated with the correct values during model refit. As a workaround for this issue, we do not use dummy weight initialization for vLLM with Gemma-3 models and instead use the `load_format="auto"` setting to load the full weights at initialization.
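
A minimal sketch of how this selection could look, assuming the `ModelFlag.VLLM_LOAD_FORMAT_AUTO` flag from `nemo_rl/models/huggingface/common.py` and vLLM's standard `dummy` load format for the fast path (the exact call site in NeMo RL may differ):

```python
# Sketch: choose the vLLM load format based on the model family.
from nemo_rl.models.huggingface.common import ModelFlag

def pick_vllm_load_format(model_name: str) -> str:
    # Gemma-3 needs real weights at init so vLLM's extra `normalizer` buffer
    # holds correct values; other models can keep fast dummy initialization.
    if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name):
        return "auto"
    return "dummy"
```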
1 change: 0 additions & 1 deletion examples/configs/grpo_math_1B.yaml
@@ -32,7 +32,6 @@ checkpointing:
checkpoint_must_save_by: null

policy:
# Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227)
model_name: "Qwen/Qwen2.5-1.5B"
tokenizer:
name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
@@ -32,7 +32,6 @@ checkpointing:
checkpoint_must_save_by: null

policy:
# Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227)
model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer:
name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
11 changes: 2 additions & 9 deletions nemo_rl/models/dtensor/parallelize.py
@@ -342,19 +342,12 @@ def get_hf_tp_plan(model: PreTrainedModel):
)

# hf tp plan does not contain embed_tokens, so we add it and set it to rowwise_rep
if (
f"{model_prefix}.embed_tokens" not in hf_tp_plan
and not model.config.tie_word_embeddings
):
if f"{model_prefix}.embed_tokens" not in hf_tp_plan:
hf_tp_plan[f"{model_prefix}.embed_tokens"] = "rowwise_rep"

for k, v in hf_tp_plan.items():
# speed up the tp plan for lm_head
if (
k == "lm_head"
and v == "colwise_rep"
and not model.config.tie_word_embeddings
):
if k == "lm_head" and v == "colwise_rep":
hf_tp_plan[k] = ColwiseParallel(
output_layouts=Shard(-1), use_local_output=False
)
6 changes: 0 additions & 6 deletions nemo_rl/models/huggingface/common.py
@@ -39,22 +39,16 @@ class ModelFlag(Enum):
configuration in different parts of the NeMo RL codebase.

Flags:
SKIP_DTENSOR_TIED_WEIGHTS_CHECK: Models that should skip the tied weights check
for the DTensor Policy even without setting the
NRL_SKIP_TIED_WEIGHT_CHECK flag.
VLLM_LOAD_FORMAT_AUTO: Models that should use the "auto" load format when initializing
VLLM.

Each flag has a `matches` method that determines if the flag applies to a given model_name.
"""

SKIP_DTENSOR_TIED_WEIGHTS_CHECK = auto()
VLLM_LOAD_FORMAT_AUTO = auto()

def matches(self, model_name: str) -> bool:
match self:
case ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK:
return is_gemma_model(model_name)
case ModelFlag.VLLM_LOAD_FORMAT_AUTO:
return is_gemma_model(model_name)
case _:
17 changes: 1 addition & 16 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -42,7 +42,6 @@
AutoModelForSequenceClassification,
AutoTokenizer,
)
from transformers.integrations.accelerate import find_tied_parameters
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM

from nemo_rl.algorithms.interfaces import LossFunction, LossType
@@ -56,7 +55,6 @@
to_local_if_dtensor,
)
from nemo_rl.models.huggingface.common import (
ModelFlag,
get_flash_attention_kwargs,
pack_sequences,
)
@@ -267,12 +265,8 @@ def __init__(
self.model.config.pad_token_id = tokenizer.pad_token_id

# caching since this property is not always preserved after FSDP
self.num_tied_weights = len(find_tied_parameters(self.model))
self.skip_tie_check = os.environ.get(
"NRL_SKIP_TIED_WEIGHT_CHECK"
) or ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name)

self.tokenizer = tokenizer

# ------------------------------------------------
# 3) Move to GPU + Composable FSDP
# (Initialize device mesh, shard submodules, then shard entire model)
@@ -528,15 +522,6 @@ def train(
mbs: Optional[int] = None,
) -> dict[str, Any]:
"""Train the policy on a batch of data with a given loss function."""
# Check if the model has tied weights
if (
self.num_tied_weights != 0
and self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1
and not self.skip_tie_check
):
raise ValueError(
f"Using dtensor policy with tp size {self.cfg['dtensor_cfg']['tensor_parallel_size']} for model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={self.num_tied_weights}) is not supported (https://github.com/NVIDIA-NeMo/RL/issues/227). Please use dtensor policy with tensor parallel == 1 instead."
)
if gbs is None:
gbs = self.cfg["train_global_batch_size"]
if mbs is None:
11 changes: 0 additions & 11 deletions tests/unit/models/generation/test_vllm_generation.py
@@ -302,17 +302,6 @@ def test_input_data(tokenizer):
)


@pytest.fixture(scope="module", autouse=True)
def skip_tied_weight_check_for_all():
"""Automatically skip tied weight check for all tests in this module."""
os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"

yield

# Restore the original value
os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)


def test_vllm_missing_required_config_key(cluster):
"""Test that an assertion error is raised when a required config key is missing."""
# Create a config missing a required key by removing 'model_name'
9 changes: 0 additions & 9 deletions tests/unit/models/generation/test_vllm_large_model.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from copy import deepcopy

import pytest
@@ -63,14 +62,6 @@
}


@pytest.fixture(scope="module", autouse=True)
def skip_tied_weight_check():
"""Automatically skip tied weight check for all tests in this module."""
os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"
yield
os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)


@pytest.fixture(scope="function")
def two_node_cluster():
"""Create a virtual cluster with 2 nodes for testing large models."""
2 changes: 0 additions & 2 deletions tests/unit/models/huggingface/test_common.py
@@ -39,7 +39,6 @@
)
def test_gemma_models(model_name):
assert is_gemma_model(model_name)
assert ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name)
assert ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name)


@@ -54,5 +53,4 @@ def test_gemma_models(model_name):
)
def test_non_gemma_models(model_name):
assert not is_gemma_model(model_name)
assert not ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name)
assert not ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name)
12 changes: 0 additions & 12 deletions tests/unit/models/policy/test_dtensor_worker.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pprint

import pytest
@@ -107,17 +106,6 @@ def create_test_config(
}


@pytest.fixture(scope="module", autouse=True)
def skip_tied_weight_check_for_all():
"""Automatically skip tied weight check for all tests in this module."""
os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"

yield

# Restore the original value
os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)


@pytest.fixture(scope="module")
def two_gpu_virtual_cluster():
cluster_name = "test"
8 changes: 0 additions & 8 deletions tests/unit/models/policy/test_megatron_worker.py
@@ -133,14 +133,6 @@ def create_megatron_test_config(
}


@pytest.fixture(scope="module", autouse=True)
def skip_tied_weight_check_for_all():
"""Automatically skip tied weight check for all tests in this module."""
os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"
yield
os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)


@pytest.fixture(scope="function")
def gc_collect():
"""Helper function to force garbage collection after a test"""
11 changes: 0 additions & 11 deletions tests/unit/utils/test_native_checkpoint.py
@@ -130,17 +130,6 @@ def policy(cluster, tokenizer):
policy.worker_group.shutdown()


@pytest.fixture(scope="module", autouse=True)
def skip_tied_weight_check_for_all():
"""Automatically skip tied weight check for all tests in this module."""
os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"

yield

# Restore the original value
os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)


def get_dummy_state_dict(state_dict, dummy_dict={}):
"""Recursively get the dummy state dict
by replacing tensors with random ones of the same shape.