12 changes: 9 additions & 3 deletions docs/testing.md
@@ -4,9 +4,15 @@ This guide outlines how to test NeMo RL using unit and functional tests, detaili

## Unit Tests

:::{important}
Unit tests require 2 GPUs to test the full suite.
:::
> [!IMPORTANT]
> Unit tests require 2 GPUs to test the full suite.

> [!TIP]
> Some unit tests require setting up test assets which you can download with
> ```sh
> uv run tests/unit/prepare_unit_test_assets.py
> ```


```sh
# Run the unit tests using local GPUs
```
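A sketch of the full local flow, assuming a 2-GPU machine; the run command mirrors the CI scripts later in this diff (which additionally pass coverage and `--hf-gated` flags):

```sh
# Prepare special-case test assets (e.g. the tiny Nemotron-H checkpoint) first
uv run tests/unit/prepare_unit_test_assets.py
# Then run the unit tests using local GPUs
uv run --no-sync bash -x ./tests/run_unit.sh unit/
```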
3 changes: 3 additions & 0 deletions nemo_rl/distributed/virtual_cluster.py
@@ -45,6 +45,9 @@ class PY_EXECUTABLES:
# Use NeMo-RL direct dependencies.
BASE = "uv run --locked"

# Use NeMo-RL direct dependencies and automodel.
AUTOMODEL = "uv run --locked --extra automodel"

# Use NeMo-RL direct dependencies and vllm.
VLLM = "uv run --locked --extra vllm"

85 changes: 84 additions & 1 deletion nemo_rl/models/dtensor/parallelize.py
@@ -357,6 +357,76 @@ def get_hf_tp_plan(model: PreTrainedModel):
return hf_tp_plan


def _parallelize_nm5_h(
model,
dp_mesh: DeviceMesh,
tp_mesh: DeviceMesh,
param_dtype: torch.dtype,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
cpu_offload: bool = False,
custom_parallel_plan: Optional[Union[dict, str]] = None,
) -> torch.distributed.fsdp.FSDPModule:
"""Parallelize a NemotronHForCausalLM model across data and tensor parallel dimensions."""
assert not sequence_parallel, (
"Sequence parallelism is not supported for NemotronHForCausalLM"
)
assert custom_parallel_plan is None, (
"Custom parallel plan is not supported for NemotronHForCausalLM"
)

model_tp_plan: dict[str, ParallelStyle] = {
"lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
}

mlp_tp_plan: dict[str, ParallelStyle] = {
"mixer.up_proj": ColwiseParallel(),
"mixer.down_proj": RowwiseParallel(),
}

layers: torch.nn.ModuleList = model.backbone.layers
parallelize_module(model, tp_mesh, model_tp_plan)

for layer in model.backbone.layers:
if layer.block_type == "mlp":
parallelize_module(layer, tp_mesh, mlp_tp_plan)

if activation_checkpointing:
for i in range(len(layers)):
if layers[i].block_type == "mlp":
layers[i] = checkpoint_wrapper(layers[i])

if layers[i].block_type == "mamba":
layers[i] = checkpoint_wrapper(layers[i])

mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype,
reduce_dtype=torch.float32,
output_dtype=torch.float32,
)

offload_policy = (
CPUOffloadPolicy(pin_memory=False)
if cpu_offload
else torch.distributed.fsdp.OffloadPolicy
)

for layer in layers:
fully_shard(
layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy
)

# do not reshard after forward for root model
# because its parameters will be used in backward immediately
return fully_shard(
model,
mesh=dp_mesh,
mp_policy=mp_policy,
offload_policy=offload_policy,
reshard_after_forward=False,
)


def _parallelize_model(
model: Union[
Qwen2ForCausalLM,
@@ -394,7 +464,20 @@ def _parallelize_model(
ValueError: If the model type is not supported for parallelization.
"""
model_cls = type(model)
if model_cls == Gemma3ForConditionalGeneration:
if model_cls.__name__ == "NemotronHForCausalLM":
# Nemotron-H (nm5) needs special handling because its mamba layers are harder to shard.
# NemotronHForCausalLM is not importable here, so dispatch on the class __name__ instead.
return _parallelize_nm5_h(
model,
dp_mesh,
tp_mesh,
param_dtype,
sequence_parallel,
activation_checkpointing,
cpu_offload,
custom_parallel_plan,
)
elif model_cls == Gemma3ForConditionalGeneration:
layers: torch.nn.ModuleList = model.language_model.layers # type: ignore
num_attention_heads = model.config.text_config.num_attention_heads
num_key_value_heads = model.config.text_config.num_key_value_heads
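For readers less familiar with the DTensor/FSDP2 APIs used in `_parallelize_nm5_h`, here is a minimal standalone sketch of the same `ColwiseParallel`/`RowwiseParallel` plus `fully_shard` pattern. It is a toy module, not code from this PR, and assumes it is launched under `torchrun` with 4 GPUs:

```python
# Toy sketch of the TP + FSDP2 composition used in _parallelize_nm5_h.
# Launch with: torchrun --nproc-per-node=4 tp_fsdp_sketch.py
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(torch.nn.Module):
    def __init__(self, hidden: int = 256, inter: int = 1024):
        super().__init__()
        self.up_proj = torch.nn.Linear(hidden, inter, bias=False)
        self.down_proj = torch.nn.Linear(inter, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(torch.nn.functional.relu(self.up_proj(x)))


def main() -> None:
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # 2-way data parallel x 2-way tensor parallel mesh (4 ranks total).
    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    block = ToyMLP().cuda()

    # Column-parallel up_proj feeds row-parallel down_proj, mirroring mlp_tp_plan.
    parallelize_module(
        block,
        mesh["tp"],
        {"up_proj": ColwiseParallel(), "down_proj": RowwiseParallel()},
    )

    # Wrap with FSDP2 on the data-parallel mesh, as done per layer in the PR.
    fully_shard(
        block,
        mesh=mesh["dp"],
        mp_policy=MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
        ),
    )

    x = torch.randn(8, 256, device="cuda")
    print(block(x).shape)


if __name__ == "__main__":
    main()
```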
1 change: 1 addition & 0 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -259,6 +259,7 @@ def __init__(
with init_empty_weights():
self.model = model_class.from_config(
model_config,
trust_remote_code=True,
)

if self.model.config.pad_token_id is None:
10 changes: 9 additions & 1 deletion pyproject.toml
@@ -55,11 +55,17 @@ automodel = [
# https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108
# https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76
"flash-attn==2.7.4.post1",
"mamba-ssm",
"causal-conv1d",
]
vllm = [
"vllm==0.10.0",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"flash-attn==2.7.4.post1",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"mamba-ssm",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"causal-conv1d",
]
mcore = [
# also need cudnn (https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_network)
@@ -132,6 +138,8 @@ torchvision = [
triton = [
{ index = "pytorch-cu128" },
]
causal-conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d", tag = "v1.5.0.post8" }
mamba-ssm = { git = "https://github.com/state-spaces/mamba.git", rev = "2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" }

[tool.uv.workspace]
members = [
@@ -145,7 +153,7 @@ url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.uv]
no-build-isolation-package = ["transformer-engine-torch", "transformer-engine", "flash-attn"]
no-build-isolation-package = ["transformer-engine-torch", "transformer-engine", "flash-attn", "mamba-ssm", "causal-conv1d"]
# Always apply the build group since dependencies like TE/mcore/nemo-run require build dependencies
# and this lets us assume they are implicitly installed with a simple `uv sync`. Ideally, we'd
# avoid including these in the default dependency set, but for now it's required.
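As a quick sanity check (a sketch, not part of the diff; the module names `mamba_ssm` and `causal_conv1d` are what these packages install, assumed here rather than taken from the changes), the new automodel extra can be exercised with:

```sh
# mamba-ssm and causal-conv1d come from the pinned git sources above and are
# built without build isolation (see no-build-isolation-package in [tool.uv]).
uv run --locked --extra automodel python -c "import mamba_ssm, causal_conv1d; print('ok')"
```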
2 changes: 2 additions & 0 deletions tests/unit/L0_Unit_Tests_Generation.sh
@@ -15,6 +15,8 @@
#!/bin/bash
set -xeuo pipefail # Exit immediately if a command exits with a non-zero status

uv run tests/unit/prepare_unit_test_assets.py

cd /opt/nemo-rl
uv run --no-sync bash -x ./tests/run_unit.sh unit/models/generation/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated

2 changes: 2 additions & 0 deletions tests/unit/L0_Unit_Tests_Other.sh
@@ -15,6 +15,8 @@
#!/bin/bash
set -xeuo pipefail # Exit immediately if a command exits with a non-zero status

uv run tests/unit/prepare_unit_test_assets.py

cd /opt/nemo-rl
uv run --no-sync bash -x ./tests/run_unit.sh unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated

2 changes: 2 additions & 0 deletions tests/unit/L0_Unit_Tests_Policy.sh
@@ -15,6 +15,8 @@
#!/bin/bash
set -xeuo pipefail # Exit immediately if a command exits with a non-zero status

uv run tests/unit/prepare_unit_test_assets.py

cd /opt/nemo-rl
uv run --no-sync bash -x ./tests/run_unit.sh unit/models/policy/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated

46 changes: 46 additions & 0 deletions tests/unit/conftest.py
@@ -576,3 +576,49 @@ def tiny_gemma3_model_path():
tokenizer.save_pretrained(model_path)
del model, tokenizer
yield model_path


def _build_tiny_nemotron5_h_checkpoint(model_path: str) -> None:
import shutil

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

config = AutoConfig.from_pretrained(
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
)
config.hybrid_override_pattern = "M*-"
config.num_hidden_layers = 3
config.intermediate_size = 32
config.hidden_size = 256
config.num_attention_heads = 8
config.mamba_num_heads = 8
config.num_key_value_heads = 8
config.n_groups = 1

model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
)

shutil.rmtree(model_path, ignore_errors=True)
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)


@pytest.fixture(scope="session")
def tiny_nemotron5_h_model_path():
"""Fixture that returns a path to a tiny Nemotron-H model with the Nemotron tokenizer.

If the asset hasn't been prepared by the prepare script, skip the tests that require it.
"""
model_path = os.path.join(
TEST_ASSETS_DIR, "tiny_nemotron5_h_with_nemotron_tokenizer"
)

config_file = os.path.join(model_path, "config.json")
if not os.path.exists(config_file):
pytest.skip(
"Tiny Nemotron-H test asset not prepared. Run `uv run tests/unit/prepare_unit_test_assets.py` first."
)

yield model_path
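As a usage sketch (a hypothetical test, not part of this PR), any test that requests the fixture receives the path to the tiny checkpoint and is skipped automatically with the message above when the asset has not been prepared:

```python
# Hypothetical test illustrating the fixture; skipped if the asset under
# tests/unit/test_assets has not been built by prepare_unit_test_assets.py.
from transformers import AutoConfig


def test_tiny_nemotron_h_config(tiny_nemotron5_h_model_path):
    config = AutoConfig.from_pretrained(
        tiny_nemotron5_h_model_path, trust_remote_code=True
    )
    # Values set by the prepare script's tiny config.
    assert config.num_hidden_layers == 3
    assert config.hidden_size == 256
```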
7 changes: 7 additions & 0 deletions tests/unit/models/policy/test_dtensor_worker.py
@@ -328,6 +328,13 @@ def training_setup(request, two_gpu_virtual_cluster):
("tiny_gemma3_model_path", 1, 1, False, True, True),
("tiny_gemma3_model_path", 1, 1, True, True, True),
# CP doesn't support gemma3 because the SDPA input has attn_mask != None.
# Nemotron-H doesn't support SP https://github.com/NVIDIA-NeMo/RL/issues/881
# ("tiny_nemotron5_h_model_path", 1, 1, True, True, False),
# ("tiny_nemotron5_h_model_path", 1, 1, True, False, True),
# ("tiny_nemotron5_h_model_path", 1, 1, True, True, True),
("tiny_nemotron5_h_model_path", 1, 1, False, False, False),
("tiny_nemotron5_h_model_path", 1, 1, False, True, True),
# Nemotron-H doesn't support CP
],
indirect=True,
)
98 changes: 98 additions & 0 deletions tests/unit/prepare_unit_test_assets.py
@@ -0,0 +1,98 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script prepares any unit test asset that requires special handling.

The initial motivation was Nemotron-H, which requires mamba-ssm in the
environment in order to initialize a dummy model. Since the unit tests should
be runnable with the base environment (without mamba-ssm), we build the asset
here via ray.remote in a dedicated venv. We do this outside of a fixture, unlike
the other test assets, because this one can take a while to build, and the extra
setup time can cause unit-test timeouts if we are unlucky.
"""

import os

import ray

from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES
from nemo_rl.utils.venvs import create_local_venv

TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_ASSETS_DIR = os.path.join(TESTS_DIR, "test_assets")


def build_tiny_nemotron5_h_checkpoint(model_path: str) -> None:
import shutil

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

config = AutoConfig.from_pretrained(
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
)
config.hybrid_override_pattern = "M*-"
config.num_hidden_layers = 3
config.intermediate_size = 32
config.hidden_size = 256
config.num_attention_heads = 8
config.mamba_num_heads = 8
config.num_key_value_heads = 8
config.n_groups = 1

model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
)

shutil.rmtree(model_path, ignore_errors=True)
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)
print(f"✓ Built tiny Nemotron-H asset at: {model_path}")


def main() -> None:
os.makedirs(TEST_ASSETS_DIR, exist_ok=True)

target = os.path.join(TEST_ASSETS_DIR, "tiny_nemotron5_h_with_nemotron_tokenizer")

# Create Automodel env venv
automodel_python = create_local_venv(
py_executable=PY_EXECUTABLES.AUTOMODEL, venv_name="automodel_env"
)

############################################################################
# Add other remote calls here
############################################################################
# Submit as list of remote calls and wait individually
remote_calls = [
ray.remote(build_tiny_nemotron5_h_checkpoint)
.options(
num_gpus=0.01, # tiny reservation to satisfy CUDA-inspecting deps
runtime_env={"py_executable": automodel_python},
name="build-nemotron5h",
)
.remote(target)
]

for obj_ref in remote_calls:
ray.get(obj_ref)


if __name__ == "__main__":
if not ray.is_initialized():
ray.init(ignore_reinit_error=True, include_dashboard=False)
try:
main()
finally:
ray.shutdown()