3 changes: 3 additions & 0 deletions nemo_rl/distributed/virtual_cluster.py
@@ -45,6 +45,9 @@ class PY_EXECUTABLES:
    # Use NeMo-RL direct dependencies.
    BASE = "uv run --locked"

    # Use NeMo-RL direct dependencies and automodel.
    AUTOMODEL = "uv run --locked --extra automodel"

    # Use NeMo-RL direct dependencies and vllm.
    VLLM = "uv run --locked --extra vllm"
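
For context, the new AUTOMODEL entry is consumed the same way as the other executables: it is passed to Ray as a runtime_env `py_executable` so a task runs inside the automodel extra's environment (the new conftest fixture below does exactly this). A minimal sketch, assuming a local Ray cluster; the helper name `report_torch_version` is only for illustration:

import ray

from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES


@ray.remote
def report_torch_version() -> str:
    # Runs inside the `uv run --locked --extra automodel` environment.
    import torch

    return torch.__version__


ray.init()
print(
    ray.get(
        report_torch_version.options(
            runtime_env={"py_executable": PY_EXECUTABLES.AUTOMODEL}
        ).remote()
    )
)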

85 changes: 84 additions & 1 deletion nemo_rl/models/dtensor/parallelize.py
@@ -357,6 +357,76 @@ def get_hf_tp_plan(model: PreTrainedModel):
    return hf_tp_plan


def _parallelize_nm5_h(
    model,
    dp_mesh: DeviceMesh,
    tp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    sequence_parallel: bool = False,
    activation_checkpointing: bool = False,
    cpu_offload: bool = False,
    custom_parallel_plan: Optional[Union[dict, str]] = None,
) -> torch.distributed.fsdp.FSDPModule:
    """Parallelize a NemotronHForCausalLM model across data and tensor parallel dimensions."""
    assert not sequence_parallel, (
        "Sequence parallelism is not supported for NemotronHForCausalLM"
    )
    assert custom_parallel_plan is None, (
        "Custom parallel plan is not supported for NemotronHForCausalLM"
    )

    model_tp_plan: dict[str, ParallelStyle] = {
        "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
    }

    mlp_tp_plan: dict[str, ParallelStyle] = {
        "mixer.up_proj": ColwiseParallel(),
        "mixer.down_proj": RowwiseParallel(),
    }

    layers: torch.nn.ModuleList = model.backbone.layers
    parallelize_module(model, tp_mesh, model_tp_plan)

    for layer in model.backbone.layers:
        if layer.block_type == "mlp":
            parallelize_module(layer, tp_mesh, mlp_tp_plan)

    if activation_checkpointing:
        for i in range(len(layers)):
            if layers[i].block_type == "mlp":
                layers[i] = checkpoint_wrapper(layers[i])

            if layers[i].block_type == "mamba":
                layers[i] = checkpoint_wrapper(layers[i])

    mp_policy = MixedPrecisionPolicy(
        param_dtype=param_dtype,
        reduce_dtype=torch.float32,
        output_dtype=torch.float32,
    )

    offload_policy = (
        CPUOffloadPolicy(pin_memory=False)
        if cpu_offload
        else torch.distributed.fsdp.OffloadPolicy
    )

    for layer in layers:
        fully_shard(
            layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy
        )

    # Do not reshard the root model after forward because its parameters
    # will be used in backward immediately.
    return fully_shard(
        model,
        mesh=dp_mesh,
        mp_policy=mp_policy,
        offload_policy=offload_policy,
        reshard_after_forward=False,
    )


def _parallelize_model(
    model: Union[
        Qwen2ForCausalLM,

@@ -394,7 +464,20 @@ def _parallelize_model(
        ValueError: If the model type is not supported for parallelization.
    """
    model_cls = type(model)
    if model_cls == Gemma3ForConditionalGeneration:
    if model_cls.__name__ == "NemotronHForCausalLM":
        # Nemotron-H (nm5) needs special handling since its Mamba layers are harder to shard.
        # The class is not importable here (it is loaded via trust_remote_code), so match on __name__.
        return _parallelize_nm5_h(
            model,
            dp_mesh,
            tp_mesh,
            param_dtype,
            sequence_parallel,
            activation_checkpointing,
            cpu_offload,
            custom_parallel_plan,
        )
    elif model_cls == Gemma3ForConditionalGeneration:
        layers: torch.nn.ModuleList = model.language_model.layers  # type: ignore
        num_attention_heads = model.config.text_config.num_attention_heads
        num_key_value_heads = model.config.text_config.num_key_value_heads
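
For reference, a minimal sketch of how this path could be exercised, assuming two GPUs launched with torchrun and a Nemotron-H checkpoint loadable with trust_remote_code; the direct call to the private helper and the mesh shape are illustrative only, since in NeMo-RL it is reached through _parallelize_model:

# Launch with: torchrun --nproc-per-node=2 sketch.py
import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from nemo_rl.models.dtensor.parallelize import _parallelize_nm5_h

# 1 x 2 mesh: data-parallel size 1, tensor-parallel size 2.
mesh = init_device_mesh("cuda", (1, 2), mesh_dim_names=("dp", "tp"))

model = AutoModelForCausalLM.from_pretrained(
    "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True, torch_dtype=torch.bfloat16
)

# MLP blocks get tensor parallelism; every layer (mamba, attention, mlp) is FSDP-sharded.
sharded = _parallelize_nm5_h(
    model,
    dp_mesh=mesh["dp"],
    tp_mesh=mesh["tp"],
    param_dtype=torch.bfloat16,
    activation_checkpointing=True,
)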
1 change: 1 addition & 0 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -259,6 +259,7 @@ def __init__(
        with init_empty_weights():
            self.model = model_class.from_config(
                model_config,
                trust_remote_code=True,
            )

        if self.model.config.pad_token_id is None:
10 changes: 9 additions & 1 deletion pyproject.toml
@@ -55,11 +55,17 @@ automodel = [
# https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108
# https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76
"flash-attn==2.7.4.post1",
"mamba-ssm",
"causal-conv1d",
]
vllm = [
"vllm==0.10.0",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"flash-attn==2.7.4.post1",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"mamba-ssm",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"causal-conv1d",
]
mcore = [
# also need cudnn (https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_network)
@@ -132,6 +138,8 @@ torchvision = [
triton = [
{ index = "pytorch-cu128" },
]
causal-conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d", tag = "v1.5.0.post8" }
mamba-ssm = { git = "https://github.com/state-spaces/mamba.git", rev = "2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" }

[tool.uv.workspace]
members = [
@@ -145,7 +153,7 @@ url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.uv]
no-build-isolation-package = ["transformer-engine-torch", "transformer-engine", "flash-attn"]
no-build-isolation-package = ["transformer-engine-torch", "transformer-engine", "flash-attn", "mamba-ssm", "causal-conv1d"]
# Always apply the build group since dependencies like TE/mcore/nemo-run require build dependencies
# and this lets us assume they are implicitly installed with a simple `uv sync`. Ideally, we'd
# avoid including these in the default dependency set, but for now it's required.
54 changes: 54 additions & 0 deletions tests/unit/conftest.py
@@ -576,3 +576,57 @@ def tiny_gemma3_model_path():
    tokenizer.save_pretrained(model_path)
    del model, tokenizer
    yield model_path


def _build_tiny_nemotron5_h_checkpoint(model_path: str) -> None:
    import shutil

    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    config = AutoConfig.from_pretrained(
        "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
    )
    config.hybrid_override_pattern = "M*-"
    config.num_hidden_layers = 3
    config.intermediate_size = 32
    config.hidden_size = 256
    config.num_attention_heads = 8
    config.mamba_num_heads = 8
    config.num_key_value_heads = 8
    config.n_groups = 1

    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(
        "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
    )

    shutil.rmtree(model_path, ignore_errors=True)
    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)


@pytest.fixture(scope="session")
def tiny_nemotron5_h_model_path():
    """Fixture that returns a path to a tiny Nemotron-H model with a dummy tokenizer."""
    model_path = os.path.join(
        TEST_ASSETS_DIR, "tiny_nemotron5_h_with_nemotron_tokenizer"
    )

    # Run the builder inside the automodel environment using its dedicated venv python.
    from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES

    # Wrap the builder as a Ray remote function so it runs in that environment.
    build_tiny_nemotron5_h_checkpoint_remote = ray.remote(
        _build_tiny_nemotron5_h_checkpoint
    )

    ray.get(
        build_tiny_nemotron5_h_checkpoint_remote.options(
            # A GPU is needed just to import mamba-ssm; claim a tiny fraction so we don't error.
            num_gpus=0.01,
            runtime_env={"py_executable": PY_EXECUTABLES.AUTOMODEL},
            name="build-tiny-nemotron5-h",
        ).remote(model_path)
    )

    yield model_path
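
A hypothetical usage example (not part of this diff) that consumes the fixture and checks the tiny checkpoint round-trips through transformers:

def test_tiny_nemotron5_h_checkpoint_loads(tiny_nemotron5_h_model_path):
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(
        tiny_nemotron5_h_model_path, trust_remote_code=True
    )
    # Values set by the builder above.
    assert config.num_hidden_layers == 3
    assert config.hidden_size == 256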
6 changes: 6 additions & 0 deletions tests/unit/models/policy/test_dtensor_worker.py
@@ -328,6 +328,12 @@ def training_setup(request, two_gpu_virtual_cluster):
        ("tiny_gemma3_model_path", 1, 1, False, True, True),
        ("tiny_gemma3_model_path", 1, 1, True, True, True),
        # CP doesn't support gemma3 because the SDPA input has attention_mask != None.
        # Nemotron-H doesn't support SP: https://github.com/NVIDIA-NeMo/RL/issues/881
        # ("tiny_nemotron5_h_model_path", 1, 1, True, True, False),
        # ("tiny_nemotron5_h_model_path", 1, 1, True, False, True),
        ("tiny_nemotron5_h_model_path", 1, 1, False, True, True),
        ("tiny_nemotron5_h_model_path", 1, 1, True, True, True),
        # Nemotron-H doesn't support CP.
    ],
    indirect=True,
)
33 changes: 33 additions & 0 deletions uv.lock

Some generated files are not rendered by default.