Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions nemo_rl/models/megatron/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import time
import warnings
from typing import Any, Optional, TypeVar
from typing import Any, Optional, TypeVar, get_args

import torch
from megatron.bridge import AutoBridge
Expand All @@ -33,6 +33,7 @@
DistributedDataParallelConfig,
LoggerConfig,
OptimizerConfig,
RNGConfig,
SchedulerConfig,
TokenizerConfig,
TrainingConfig,
Expand Down Expand Up @@ -73,7 +74,7 @@
from nemo_rl.distributed.named_sharding import NamedSharding
from nemo_rl.models.megatron.community_import import import_model_from_hf_name
from nemo_rl.models.megatron.config import ModelAndOptimizerState, RuntimeConfig
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.models.policy import CudaGraphScope, PolicyConfig
from nemo_rl.models.policy.utils import (
configure_dynamo_cache,
get_megatron_checkpoint_dir,
Expand Down Expand Up @@ -481,6 +482,28 @@ def _apply_performance_config(model_cfg: Any, config: PolicyConfig) -> None:
f"Unsupported {type(attention_backend)=}, expected str or int"
)

# CUDA Graph configuration
if "enable_cuda_graph" in config["megatron_cfg"]:
model_cfg.enable_cuda_graph = config["megatron_cfg"]["enable_cuda_graph"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add these configs to

class MegatronConfig(TypedDict):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added Literal definition using reference from Megatron-Bridge

if "cuda_graph_scope" in config["megatron_cfg"]:
scope = config["megatron_cfg"]["cuda_graph_scope"]
valid_scopes = get_args(CudaGraphScope)
if scope not in valid_scopes:
raise ValueError(
f"Invalid cuda_graph_scope '{scope}'. "
f"Valid options are: {valid_scopes}"
)
model_cfg.cuda_graph_scope = scope
if not model_cfg.enable_cuda_graph:
warnings.warn(
"cuda_graph_scope is configured but enable_cuda_graph is False. "
"The cuda_graph_scope setting will have no effect."
)
if model_cfg.enable_cuda_graph:
model_cfg.use_te_rng_tracker = True
else:
model_cfg.use_te_rng_tracker = False

# FP8 configuration
fp8_cfg = config["megatron_cfg"].get("fp8_cfg", None)
if fp8_cfg is not None and fp8_cfg.get("enabled", False):
Expand Down Expand Up @@ -607,10 +630,16 @@ def _create_megatron_config(
dtype: torch.dtype,
) -> ConfigContainer:
"""Create the final Megatron configuration container."""
# Create RNG config with CUDA graph support
rng_config = RNGConfig(
te_rng_tracker=config["megatron_cfg"].get("enable_cuda_graph", False),
)

return ConfigContainer(
model=model_cfg,
checkpoint=checkpoint_config,
logger=LoggerConfig(logging_level=0),
rng=rng_config,
train=TrainingConfig(
micro_batch_size=1, # ignored
global_batch_size=config["train_global_batch_size"], # ignored
Expand Down
8 changes: 7 additions & 1 deletion nemo_rl/models/policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Literal, NotRequired, TypedDict, Union
from typing import Any, Literal, NotRequired, TypedDict

from nemo_rl.models.generation.interfaces import GenerationConfig

# Allowed values for `megatron_cfg["cuda_graph_scope"]`; validated via
# typing.get_args(CudaGraphScope) at config-apply time. The value list is
# taken from Megatron-Bridge's cuda_graph_scope options (per PR review note)
# — keep in sync with upstream if Megatron-Bridge adds new scopes.
CudaGraphScope = Literal[
    "full_iteration", "attn", "mlp", "moe", "moe_router", "moe_preprocess", "mamba"
]


class LoRAConfigDisabled(TypedDict):
enabled: Literal[False]
Expand Down Expand Up @@ -224,6 +228,8 @@ class MegatronConfig(TypedDict):
# Number of tokens per chunk when computing the fused linear CE loss.
# Smaller values reduce peak memory further but may decrease throughput.
linear_ce_fusion_chunk_size: NotRequired[int]
enable_cuda_graph: bool
cuda_graph_scope: CudaGraphScope


class TokenizerConfig(TypedDict):
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/models/policy/test_megatron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def create_megatron_test_config(
converter_type: str = "LlamaForCausalLM",
logprob_chunk_size: Optional[int] = None,
defer_fp32_logits: Optional[bool] = None,
enable_cuda_graph: Optional[bool] = None,
cuda_graph_scope: Optional[str] = None,
) -> PolicyConfig:
"""Create a test config for Megatron policy worker."""
return {
Expand Down Expand Up @@ -138,6 +140,8 @@ def create_megatron_test_config(
"moe_token_dispatcher_type": "alltoall",
"moe_shared_expert_overlap": False,
"defer_fp32_logits": defer_fp32_logits,
**({"enable_cuda_graph": enable_cuda_graph} if enable_cuda_graph is not None else {}),
**({"cuda_graph_scope": cuda_graph_scope} if cuda_graph_scope is not None else {}),
"use_linear_ce_fusion_loss": False,
"linear_ce_fusion_chunk_size": 256,
"train_iters": 100, # Required for Megatron training
Expand Down Expand Up @@ -2686,3 +2690,37 @@ def test_megatron_policy_flops_range_check(tiny_llama_model_path):
finally:
policy.shutdown()
cluster.shutdown()


def test_cuda_graph_config_parsing():
    """Verify CUDA graph options are conditionally included in test configs.

    ``create_megatron_test_config`` inserts ``enable_cuda_graph`` /
    ``cuda_graph_scope`` into ``megatron_cfg`` only when the corresponding
    keyword argument is not ``None``, so absence of the keys (not a ``None``
    value) is the expected "unset" state.
    """
    # Default: neither option passed, so neither key should appear at all.
    config_default = create_megatron_test_config(
        model_name="test-model",
    )
    assert "enable_cuda_graph" not in config_default["megatron_cfg"]
    assert "cuda_graph_scope" not in config_default["megatron_cfg"]

    # CUDA graph enabled without a scope: only the enable flag is present.
    config_enabled = create_megatron_test_config(
        model_name="test-model",
        enable_cuda_graph=True,
    )
    assert config_enabled["megatron_cfg"]["enable_cuda_graph"] is True
    assert "cuda_graph_scope" not in config_enabled["megatron_cfg"]

    # Both options set: both keys round-trip with the supplied values.
    config_with_scope = create_megatron_test_config(
        model_name="test-model",
        enable_cuda_graph=True,
        cuda_graph_scope="full_iteration",
    )
    assert config_with_scope["megatron_cfg"]["enable_cuda_graph"] is True
    assert config_with_scope["megatron_cfg"]["cuda_graph_scope"] == "full_iteration"

    # Explicitly disabled: the flag is present as False, and no scope key
    # leaks in (previously this case left the scope's absence unchecked).
    config_disabled = create_megatron_test_config(
        model_name="test-model",
        enable_cuda_graph=False,
    )
    assert config_disabled["megatron_cfg"]["enable_cuda_graph"] is False
    assert "cuda_graph_scope" not in config_disabled["megatron_cfg"]
Loading