diff --git a/docs/model-quirks.md b/docs/model-quirks.md
index 7a79a95e66..52869bf04d 100644
--- a/docs/model-quirks.md
+++ b/docs/model-quirks.md
@@ -4,14 +4,6 @@ This document outlines special cases and model-specific behaviors that require c
 
 ## Gemma-3
 
-### Tied Weights
-
-Weight tying between the embedding layer (`model.embed_tokens`) and output layer (`lm_head`) is currently not respected when using the DTensor policy when TP > 1 (See [this issue](https://github.com/NVIDIA-NeMo/RL/issues/227)). To avoid errors when training these models, we only allow training models with tied weights using the DTensor policy with TP=1. For Llama-3 and Qwen2.5 models, weight-tying is only enabled for the smaller models (< 2B), which can typically be trained without tensor parallelism. For Gemma-3, all model sizes have weight-tying enabled, including the larger models which require tensor parallelism. To support training of these models, we specially handle the Gemma-3 models by allowing training using the DTensor policy with TP > 1.
-
-**Special Handling:**
-- We skip the tied weights check for all Gemma-3 models when using the DTensor policy, allowing training using TP > 1.
-- We exclude `model.embed_tokens` and `lm_head` from the DTensor tensor parallel plan to maintain weight tying correctly.
-
 ### vLLM Initialization
 
 Gemma-3 models have a specific issue with vLLM dummy weight initialization due to a vLLM bug where [a `normalizer` buffer is created](https://github.com/vllm-project/vllm/blob/964472b9667508b1d4a7ed92068ff81740ae0036/vllm/model_executor/models/gemma3.py#L372) that is not present in the Hugging Face model. This causes the `normalizer` buffer to be set to dummy weights at initialization and then never updated with the correct values during model refit. As a workaround for this issue, we do not use dummy weight initialization for vLLM with Gemma-3 models and instead use the `load_format="auto"` setting to load the full weights at initialization.
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index f4f668db91..5d3daff3aa 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -32,7 +32,6 @@ checkpointing:
   checkpoint_must_save_by: null
 
 policy:
-  # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227)
   model_name: "Qwen/Qwen2.5-1.5B"
   tokenizer:
     name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
diff --git a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml
index a8a21fac83..7513390aaa 100644
--- a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml
+++ b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml
@@ -32,7 +32,6 @@ checkpointing:
   checkpoint_must_save_by: null
 
 policy:
-  # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227)
   model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
   tokenizer:
     name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py
index e2af748d71..7de646d47a 100644
--- a/nemo_rl/models/dtensor/parallelize.py
+++ b/nemo_rl/models/dtensor/parallelize.py
@@ -342,19 +342,12 @@ def get_hf_tp_plan(model: PreTrainedModel):
     )
 
     # hf tp plan not contain embed_tokens, we add it and set to rowwise_rep
-    if (
-        f"{model_prefix}.embed_tokens" not in hf_tp_plan
-        and not model.config.tie_word_embeddings
-    ):
+    if f"{model_prefix}.embed_tokens" not in hf_tp_plan:
         hf_tp_plan[f"{model_prefix}.embed_tokens"] = "rowwise_rep"
 
     for k, v in hf_tp_plan.items():
         # speed up the tp plan for lm_head
-        if (
-            k == "lm_head"
-            and v == "colwise_rep"
-            and not model.config.tie_word_embeddings
-        ):
+        if k == "lm_head" and v == "colwise_rep":
             hf_tp_plan[k] = ColwiseParallel(
                 output_layouts=Shard(-1), use_local_output=False
             )
diff --git a/nemo_rl/models/huggingface/common.py b/nemo_rl/models/huggingface/common.py
index c057f6d89a..ad26e36327 100644
--- a/nemo_rl/models/huggingface/common.py
+++ b/nemo_rl/models/huggingface/common.py
@@ -39,22 +39,16 @@ class ModelFlag(Enum):
     configuration in different parts of the NeMo RL codebase.
 
     Flags:
-        SKIP_DTENSOR_TIED_WEIGHTS_CHECK: Models that should skip the tied weights check
-                                         for the DTensor Policy even without setting the
-                                         NRL_SKIP_TIED_WEIGHT_CHECK flag.
         VLLM_LOAD_FORMAT_AUTO: Models that should use the "auto" load format when
                                initializing VLLM.
 
     Each flag has a `matches` method that determines if the flag applies to a given model_name.
""" - SKIP_DTENSOR_TIED_WEIGHTS_CHECK = auto() VLLM_LOAD_FORMAT_AUTO = auto() def matches(self, model_name: str) -> bool: match self: - case ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK: - return is_gemma_model(model_name) case ModelFlag.VLLM_LOAD_FORMAT_AUTO: return is_gemma_model(model_name) case _: diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 190c366ebf..1426778fc3 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -42,7 +42,6 @@ AutoModelForSequenceClassification, AutoTokenizer, ) -from transformers.integrations.accelerate import find_tied_parameters from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM from nemo_rl.algorithms.interfaces import LossFunction, LossType @@ -56,7 +55,6 @@ to_local_if_dtensor, ) from nemo_rl.models.huggingface.common import ( - ModelFlag, get_flash_attention_kwargs, pack_sequences, ) @@ -267,12 +265,8 @@ def __init__( self.model.config.pad_token_id = tokenizer.pad_token_id # caching since this property is not always preserved after FSDP - self.num_tied_weights = len(find_tied_parameters(self.model)) - self.skip_tie_check = os.environ.get( - "NRL_SKIP_TIED_WEIGHT_CHECK" - ) or ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name) - self.tokenizer = tokenizer + # ------------------------------------------------ # 3) Move to GPU + Composable FSDP # (Initialize device mesh, shard submodules, then shard entire model) @@ -528,15 +522,6 @@ def train( mbs: Optional[int] = None, ) -> dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" - # Check if the model has tied weights - if ( - self.num_tied_weights != 0 - and self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1 - and not self.skip_tie_check - ): - raise ValueError( - f"Using dtensor policy with tp size {self.cfg['dtensor_cfg']['tensor_parallel_size']} for model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={self.num_tied_weights}) is not supported (https://github.com/NVIDIA-NeMo/RL/issues/227). Please use dtensor policy with tensor parallel == 1 instead." - ) if gbs is None: gbs = self.cfg["train_global_batch_size"] if mbs is None: diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index a5ab34241a..c2c60086b9 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -302,17 +302,6 @@ def test_input_data(tokenizer): ) -@pytest.fixture(scope="module", autouse=True) -def skip_tied_weight_check_for_all(): - """Automatically skip tied weight check for all tests in this module.""" - os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1" - - yield - - # Restore the original value - os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None) - - def test_vllm_missing_required_config_key(cluster): """Test that an assertion error is raised when a required config key is missing.""" # Create a config missing a required key by removing 'model_name' diff --git a/tests/unit/models/generation/test_vllm_large_model.py b/tests/unit/models/generation/test_vllm_large_model.py index 7b93ef46d1..1b7387e832 100644 --- a/tests/unit/models/generation/test_vllm_large_model.py +++ b/tests/unit/models/generation/test_vllm_large_model.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os
 from copy import deepcopy
 
 import pytest
@@ -63,14 +62,6 @@
 }
 
 
-@pytest.fixture(scope="module", autouse=True)
-def skip_tied_weight_check():
-    """Automatically skip tied weight check for all tests in this module."""
-    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"
-    yield
-    os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
-
-
 @pytest.fixture(scope="function")
 def two_node_cluster():
     """Create a virtual cluster with 2 nodes for testing large models."""
diff --git a/tests/unit/models/huggingface/test_common.py b/tests/unit/models/huggingface/test_common.py
index 95da64b0b4..e1f7b948aa 100644
--- a/tests/unit/models/huggingface/test_common.py
+++ b/tests/unit/models/huggingface/test_common.py
@@ -39,7 +39,6 @@
 )
 def test_gemma_models(model_name):
     assert is_gemma_model(model_name)
-    assert ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name)
     assert ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name)
 
 
@@ -54,5 +53,4 @@
 )
 def test_non_gemma_models(model_name):
     assert not is_gemma_model(model_name)
-    assert not ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name)
     assert not ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name)
diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py
index ab876eb214..f6d1a7c2a8 100644
--- a/tests/unit/models/policy/test_dtensor_worker.py
+++ b/tests/unit/models/policy/test_dtensor_worker.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import pprint
 
 import pytest
@@ -107,17 +106,6 @@ def create_test_config(
     }
 
 
-@pytest.fixture(scope="module", autouse=True)
-def skip_tied_weight_check_for_all():
-    """Automatically skip tied weight check for all tests in this module."""
-    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"
-
-    yield
-
-    # Restore the original value
-    os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
-
-
 @pytest.fixture(scope="module")
 def two_gpu_virtual_cluster():
     cluster_name = "test"
diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py
index 20c31324e0..5b16b0b28a 100644
--- a/tests/unit/models/policy/test_megatron_worker.py
+++ b/tests/unit/models/policy/test_megatron_worker.py
@@ -133,14 +133,6 @@ def create_megatron_test_config(
     }
 
 
-@pytest.fixture(scope="module", autouse=True)
-def skip_tied_weight_check_for_all():
-    """Automatically skip tied weight check for all tests in this module."""
-    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"
-    yield
-    os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
-
-
 @pytest.fixture(scope="function")
 def gc_collect():
     """Helper function to force garbage collection after a test"""
diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py
index 69493da3b3..88003941cb 100755
--- a/tests/unit/utils/test_native_checkpoint.py
+++ b/tests/unit/utils/test_native_checkpoint.py
@@ -130,17 +130,6 @@ def policy(cluster, tokenizer):
     policy.worker_group.shutdown()
 
 
-@pytest.fixture(scope="module", autouse=True)
-def skip_tied_weight_check_for_all():
-    """Automatically skip tied weight check for all tests in this module."""
-    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"
-
-    yield
-
-    # Restore the original value
-    os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
-
-
 def get_dummy_state_dict(state_dict, dummy_dict={}):
     """Recursively get the dummy state dict by replacing tensors with random ones of the same shape."""
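
The `load_format="auto"` workaround kept in `docs/model-quirks.md` above is driven by the retained `ModelFlag.VLLM_LOAD_FORMAT_AUTO` flag. Below is a minimal, illustrative sketch (not part of this patch; the helper name `select_vllm_load_format` is hypothetical) of how that flag can be used to pick the vLLM load format:

```python
# Illustrative sketch only -- not part of the patch above.
# Assumes nemo_rl is importable; ModelFlag comes from
# nemo_rl.models.huggingface.common (shown in the diff).
from nemo_rl.models.huggingface.common import ModelFlag


def select_vllm_load_format(model_name: str) -> str:
    """Hypothetical helper: choose the vLLM load_format for a model.

    Gemma-3 must load real checkpoints at init ("auto") because vLLM's extra
    `normalizer` buffer is never refreshed during refit, so "dummy" weight
    initialization would leave it holding stale values.
    """
    if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name):
        return "auto"
    return "dummy"


print(select_vllm_load_format("google/gemma-3-4b-it"))  # -> "auto"
print(select_vllm_load_format("Qwen/Qwen2.5-1.5B"))  # -> "dummy"
```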