diff --git a/docs/testing.md b/docs/testing.md
index 8ce97346b9..5a24452813 100644
--- a/docs/testing.md
+++ b/docs/testing.md
@@ -4,9 +4,15 @@ This guide outlines how to test NeMo RL using unit and functional tests, detaili
 
 ## Unit Tests
 
-:::{important}
-Unit tests require 2 GPUs to test the full suite.
-:::
+> [!IMPORTANT]
+> Unit tests require 2 GPUs to test the full suite.
+
+> [!TIP]
+> Some unit tests require setting up test assets, which you can download with
+> ```sh
+> uv run tests/unit/prepare_unit_test_assets.py
+> ```
+
 
 ```sh
 # Run the unit tests using local GPUs
diff --git a/nemo_rl/distributed/virtual_cluster.py b/nemo_rl/distributed/virtual_cluster.py
index 6e0a75b880..8b7a075d3b 100644
--- a/nemo_rl/distributed/virtual_cluster.py
+++ b/nemo_rl/distributed/virtual_cluster.py
@@ -45,6 +45,9 @@ class PY_EXECUTABLES:
     # Use NeMo-RL direct dependencies.
     BASE = "uv run --locked"
 
+    # Use NeMo-RL direct dependencies and automodel.
+    AUTOMODEL = "uv run --locked --extra automodel"
+
     # Use NeMo-RL direct dependencies and vllm.
     VLLM = "uv run --locked --extra vllm"
 
diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py
index 7de646d47a..f0cdadd1a9 100644
--- a/nemo_rl/models/dtensor/parallelize.py
+++ b/nemo_rl/models/dtensor/parallelize.py
@@ -357,6 +357,76 @@ def get_hf_tp_plan(model: PreTrainedModel):
     return hf_tp_plan
 
 
+def _parallelize_nm5_h(
+    model,
+    dp_mesh: DeviceMesh,
+    tp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    sequence_parallel: bool = False,
+    activation_checkpointing: bool = False,
+    cpu_offload: bool = False,
+    custom_parallel_plan: Optional[Union[dict, str]] = None,
+) -> torch.distributed.fsdp.FSDPModule:
+    """Parallelize a NemotronHForCausalLM model across data and tensor parallel dimensions."""
+    assert not sequence_parallel, (
+        "Sequence parallelism is not supported for NemotronHForCausalLM"
+    )
+    assert custom_parallel_plan is None, (
+        "Custom parallel plan is not supported for NemotronHForCausalLM"
+    )
+
+    model_tp_plan: dict[str, ParallelStyle] = {
+        "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
+    }
+
+    mlp_tp_plan: dict[str, ParallelStyle] = {
+        "mixer.up_proj": ColwiseParallel(),
+        "mixer.down_proj": RowwiseParallel(),
+    }
+
+    layers: torch.nn.ModuleList = model.backbone.layers
+    parallelize_module(model, tp_mesh, model_tp_plan)
+
+    for layer in model.backbone.layers:
+        if layer.block_type == "mlp":
+            parallelize_module(layer, tp_mesh, mlp_tp_plan)
+
+    if activation_checkpointing:
+        for i in range(len(layers)):
+            if layers[i].block_type == "mlp":
+                layers[i] = checkpoint_wrapper(layers[i])
+
+            if layers[i].block_type == "mamba":
+                layers[i] = checkpoint_wrapper(layers[i])
+
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=param_dtype,
+        reduce_dtype=torch.float32,
+        output_dtype=torch.float32,
+    )
+
+    offload_policy = (
+        CPUOffloadPolicy(pin_memory=False)
+        if cpu_offload
+        else torch.distributed.fsdp.OffloadPolicy
+    )
+
+    for layer in layers:
+        fully_shard(
+            layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy
+        )
+
+    # Do not reshard the root model after forward because its parameters
+    # will be used again immediately in backward.
+    return fully_shard(
+        model,
+        mesh=dp_mesh,
+        mp_policy=mp_policy,
+        offload_policy=offload_policy,
+        reshard_after_forward=False,
+    )
+
+
 def _parallelize_model(
     model: Union[
         Qwen2ForCausalLM,
@@ -394,7 +464,20 @@ def _parallelize_model(
         ValueError: If the model type is not supported for parallelization.
""" model_cls = type(model) - if model_cls == Gemma3ForConditionalGeneration: + if model_cls.__name__ == "NemotronHForCausalLM": + # need to do something special for nm5, since it's harder to shard the mamba layers + # nm5 is not importable, so we check the __name__ attribute + return _parallelize_nm5_h( + model, + dp_mesh, + tp_mesh, + param_dtype, + sequence_parallel, + activation_checkpointing, + cpu_offload, + custom_parallel_plan, + ) + elif model_cls == Gemma3ForConditionalGeneration: layers: torch.nn.ModuleList = model.language_model.layers # type: ignore num_attention_heads = model.config.text_config.num_attention_heads num_key_value_heads = model.config.text_config.num_key_value_heads diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index aba599923f..992b70b582 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -259,6 +259,7 @@ def __init__( with init_empty_weights(): self.model = model_class.from_config( model_config, + trust_remote_code=True, ) if self.model.config.pad_token_id is None: diff --git a/pyproject.toml b/pyproject.toml index cb7b6f5227..5fc35e2b34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,11 +55,17 @@ automodel = [ # https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108 # https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76 "flash-attn==2.7.4.post1", + "mamba-ssm", + "causal-conv1d", ] vllm = [ "vllm==0.10.0", # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved "flash-attn==2.7.4.post1", + # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved + "mamba-ssm", + # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved + "causal-conv1d", ] mcore = [ # also need cudnn (https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_network) @@ -132,6 +138,8 @@ torchvision = [ triton = [ { index = "pytorch-cu128" }, ] +causal-conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d", tag = "v1.5.0.post8" } +mamba-ssm = { git = "https://github.com/state-spaces/mamba.git", rev = "2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" } [tool.uv.workspace] members = [ @@ -145,7 +153,7 @@ url = "https://download.pytorch.org/whl/cu128" explicit = true [tool.uv] -no-build-isolation-package = ["transformer-engine-torch", "transformer-engine", "flash-attn"] +no-build-isolation-package = ["transformer-engine-torch", "transformer-engine", "flash-attn", "mamba-ssm", "causal-conv1d"] # Always apply the build group since dependencies like TE/mcore/nemo-run require build dependencies # and this lets us assume they are implicitly installed with a simply `uv sync`. Ideally, we'd # avoid including these in the default dependency set, but for now it's required. 
diff --git a/tests/unit/L0_Unit_Tests_Generation.sh b/tests/unit/L0_Unit_Tests_Generation.sh
index 1e33bea35e..3f607cc080 100644
--- a/tests/unit/L0_Unit_Tests_Generation.sh
+++ b/tests/unit/L0_Unit_Tests_Generation.sh
@@ -15,6 +15,8 @@
 #!/bin/bash
 set -xeuo pipefail # Exit immediately if a command exits with a non-zero status
 
+uv run tests/unit/prepare_unit_test_assets.py
+
 cd /opt/nemo-rl
 uv run --no-sync bash -x ./tests/run_unit.sh unit/models/generation/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated
 
diff --git a/tests/unit/L0_Unit_Tests_Other.sh b/tests/unit/L0_Unit_Tests_Other.sh
index e86d6f887a..a639730044 100644
--- a/tests/unit/L0_Unit_Tests_Other.sh
+++ b/tests/unit/L0_Unit_Tests_Other.sh
@@ -15,6 +15,8 @@
 #!/bin/bash
 set -xeuo pipefail # Exit immediately if a command exits with a non-zero status
 
+uv run tests/unit/prepare_unit_test_assets.py
+
 cd /opt/nemo-rl
 uv run --no-sync bash -x ./tests/run_unit.sh unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated
 
diff --git a/tests/unit/L0_Unit_Tests_Policy.sh b/tests/unit/L0_Unit_Tests_Policy.sh
index 4df69728ff..6fe9309fe6 100644
--- a/tests/unit/L0_Unit_Tests_Policy.sh
+++ b/tests/unit/L0_Unit_Tests_Policy.sh
@@ -15,6 +15,8 @@
 #!/bin/bash
 set -xeuo pipefail # Exit immediately if a command exits with a non-zero status
 
+uv run tests/unit/prepare_unit_test_assets.py
+
 cd /opt/nemo-rl
 uv run --no-sync bash -x ./tests/run_unit.sh unit/models/policy/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated
 
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 1346a1173d..978131969a 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -576,3 +576,49 @@ def tiny_gemma3_model_path():
     tokenizer.save_pretrained(model_path)
     del model, tokenizer
     yield model_path
+
+
+def _build_tiny_nemotron5_h_checkpoint(model_path: str) -> None:
+    import shutil
+
+    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+    config = AutoConfig.from_pretrained(
+        "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
+    )
+    config.hybrid_override_pattern = "M*-"
+    config.num_hidden_layers = 3
+    config.intermediate_size = 32
+    config.hidden_size = 256
+    config.num_attention_heads = 8
+    config.mamba_num_heads = 8
+    config.num_key_value_heads = 8
+    config.n_groups = 1
+
+    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(
+        "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
+    )
+
+    shutil.rmtree(model_path, ignore_errors=True)
+    model.save_pretrained(model_path)
+    tokenizer.save_pretrained(model_path)
+
+
+@pytest.fixture(scope="session")
+def tiny_nemotron5_h_model_path():
+    """Fixture that returns a path to a tiny Nemotron-H model with the Nemotron tokenizer.
+
+    If the asset hasn't been prepared by the prepare script, tests that require it are skipped.
+    """
+    model_path = os.path.join(
+        TEST_ASSETS_DIR, "tiny_nemotron5_h_with_nemotron_tokenizer"
+    )
+
+    config_file = os.path.join(model_path, "config.json")
+    if not os.path.exists(config_file):
+        pytest.skip(
+            "Tiny Nemotron-H test asset not prepared. Run `uv run tests/unit/prepare_unit_test_assets.py` first."
+        )
+
+    yield model_path
diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py
index f6d1a7c2a8..457a4a8e3e 100644
--- a/tests/unit/models/policy/test_dtensor_worker.py
+++ b/tests/unit/models/policy/test_dtensor_worker.py
@@ -328,6 +328,13 @@ def training_setup(request, two_gpu_virtual_cluster):
         ("tiny_gemma3_model_path", 1, 1, False, True, True),
         ("tiny_gemma3_model_path", 1, 1, True, True, True),
         # CP doesn't support gemma3 due to spda input has attent_mask != None.
+        # Nemotron-H doesn't support SP: https://github.com/NVIDIA-NeMo/RL/issues/881
+        # ("tiny_nemotron5_h_model_path", 1, 1, True, True, False),
+        # ("tiny_nemotron5_h_model_path", 1, 1, True, False, True),
+        # ("tiny_nemotron5_h_model_path", 1, 1, True, True, True),
+        ("tiny_nemotron5_h_model_path", 1, 1, False, False, False),
+        ("tiny_nemotron5_h_model_path", 1, 1, False, True, True),
+        # Nemotron-H doesn't support CP
     ],
     indirect=True,
 )
diff --git a/tests/unit/prepare_unit_test_assets.py b/tests/unit/prepare_unit_test_assets.py
new file mode 100644
index 0000000000..6cb8344c55
--- /dev/null
+++ b/tests/unit/prepare_unit_test_assets.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script prepares any unit test asset that requires special handling.
+
+The initial motivation was Nemotron-H, which requires mamba-ssm to be installed
+just to initialize a dummy model. Since the unit tests should be runnable with
+the base environment (without mamba-ssm), we use ray.remote to build the asset
+in a separate automodel venv here. We do this outside of a fixture, unlike the
+other test assets, because this one sometimes takes a while to build, and the
+extra setup time can cause unit-test timeouts if we are unlucky.
+""" + +import os + +import ray + +from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES +from nemo_rl.utils.venvs import create_local_venv + +TESTS_DIR = os.path.dirname(os.path.abspath(__file__)) +TEST_ASSETS_DIR = os.path.join(TESTS_DIR, "test_assets") + + +def build_tiny_nemotron5_h_checkpoint(model_path: str) -> None: + import shutil + + from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + + config = AutoConfig.from_pretrained( + "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True + ) + config.hybrid_override_pattern = "M*-" + config.num_hidden_layers = 3 + config.intermediate_size = 32 + config.hidden_size = 256 + config.num_attention_heads = 8 + config.mamba_num_heads = 8 + config.num_key_value_heads = 8 + config.n_groups = 1 + + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True + ) + + shutil.rmtree(model_path, ignore_errors=True) + model.save_pretrained(model_path) + tokenizer.save_pretrained(model_path) + print(f"✓ Built tiny Nemotron-H asset at: {model_path}") + + +def main() -> None: + os.makedirs(TEST_ASSETS_DIR, exist_ok=True) + + target = os.path.join(TEST_ASSETS_DIR, "tiny_nemotron5_h_with_nemotron_tokenizer") + + # Create Automodel env venv + automodel_python = create_local_venv( + py_executable=PY_EXECUTABLES.AUTOMODEL, venv_name="automodel_env" + ) + + ############################################################################ + # Add other remote calls here + ############################################################################ + # Submit as list of remote calls and wait individually + remote_calls = [ + ray.remote(build_tiny_nemotron5_h_checkpoint) + .options( + num_gpus=0.01, # tiny reservation to satisfy CUDA-inspecting deps + runtime_env={"py_executable": automodel_python}, + name="build-nemotron5h", + ) + .remote(target) + ] + + for obj_ref in remote_calls: + ray.get(obj_ref) + + +if __name__ == "__main__": + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True, include_dashboard=False) + try: + main() + finally: + ray.shutdown() diff --git a/uv.lock b/uv.lock index c2a8c6cfe9..9cd8ecb5e3 100644 --- a/uv.lock +++ b/uv.lock @@ -402,6 +402,16 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/96/d32b941a501ab566a16358d68b6eb4e4acc373fab3c3c4d7d9e649f7b4bb/catalogue-2.0.10-py3-none-any.whl", hash = "sha256:58c2de0020aa90f4a2da7dfad161bf7b3b054c86a5f09fcedc0b2b740c109a9f", size = 17325, upload-time = "2023-09-25T06:29:23.337Z" }, ] +[[package]] +name = "causal-conv1d" +version = "1.5.0.post8" +source = { git = "https://github.com/Dao-AILab/causal-conv1d?tag=v1.5.0.post8#82867a9d2e6907cc0f637ac6aff318f696838548" } +dependencies = [ + { name = "ninja" }, + { name = "packaging" }, + { name = "torch" }, +] + [[package]] name = "cbor2" version = "5.6.5" @@ -2023,6 +2033,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, ] +[[package]] +name = "mamba-ssm" +version = "2.2.4" +source = { git = "https://github.com/state-spaces/mamba.git?rev=2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4#2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" } +dependencies = [ + { name = "einops" }, + { name = "ninja" }, + { name = "packaging" }, + { name = "setuptools" }, + 
{ name = "torch" }, + { name = "transformers" }, + { name = "triton", version = "3.3.0", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'linux'" }, + { name = "triton", version = "3.3.1", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" }, +] + [[package]] name = "markdown" version = "3.8.2" @@ -2542,7 +2567,9 @@ dependencies = [ [package.optional-dependencies] automodel = [ + { name = "causal-conv1d" }, { name = "flash-attn" }, + { name = "mamba-ssm" }, ] mcore = [ { name = "flash-attn" }, @@ -2552,7 +2579,9 @@ mcore = [ { name = "vllm" }, ] vllm = [ + { name = "causal-conv1d" }, { name = "flash-attn" }, + { name = "mamba-ssm" }, { name = "vllm" }, ] @@ -2592,6 +2621,8 @@ test = [ requires-dist = [ { name = "accelerate", specifier = ">=0.26" }, { name = "blobfile" }, + { name = "causal-conv1d", marker = "extra == 'automodel'", git = "https://github.com/Dao-AILab/causal-conv1d?tag=v1.5.0.post8" }, + { name = "causal-conv1d", marker = "extra == 'vllm'", git = "https://github.com/Dao-AILab/causal-conv1d?tag=v1.5.0.post8" }, { name = "colored", specifier = "==2.2.3" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "debugpy" }, @@ -2599,6 +2630,8 @@ requires-dist = [ { name = "flash-attn", marker = "extra == 'mcore'", specifier = "==2.7.4.post1" }, { name = "flash-attn", marker = "extra == 'vllm'", specifier = "==2.7.4.post1" }, { name = "hydra-core" }, + { name = "mamba-ssm", marker = "extra == 'automodel'", git = "https://github.com/state-spaces/mamba.git?rev=2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" }, + { name = "mamba-ssm", marker = "extra == 'vllm'", git = "https://github.com/state-spaces/mamba.git?rev=2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" }, { name = "math-verify" }, { name = "matplotlib" }, { name = "megatron-core", marker = "extra == 'mcore'", editable = "3rdparty/Megatron-LM-workspace" },