diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 6a58d930d5..d86a82e738 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -15,7 +15,7 @@ import os import time import warnings -from typing import Any, Optional, TypeVar +from typing import Any, Optional, TypeVar, get_args import torch from megatron.bridge import AutoBridge @@ -33,6 +33,7 @@ DistributedDataParallelConfig, LoggerConfig, OptimizerConfig, + RNGConfig, SchedulerConfig, TokenizerConfig, TrainingConfig, @@ -73,7 +74,7 @@ from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.models.megatron.community_import import import_model_from_hf_name from nemo_rl.models.megatron.config import ModelAndOptimizerState, RuntimeConfig -from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy import CudaGraphScope, PolicyConfig from nemo_rl.models.policy.utils import ( configure_dynamo_cache, get_megatron_checkpoint_dir, @@ -481,6 +482,28 @@ def _apply_performance_config(model_cfg: Any, config: PolicyConfig) -> None: f"Unsupported {type(attention_backend)=}, expected str or int" ) + # CUDA Graph configuration + if "enable_cuda_graph" in config["megatron_cfg"]: + model_cfg.enable_cuda_graph = config["megatron_cfg"]["enable_cuda_graph"] + if "cuda_graph_scope" in config["megatron_cfg"]: + scope = config["megatron_cfg"]["cuda_graph_scope"] + valid_scopes = get_args(CudaGraphScope) + if scope not in valid_scopes: + raise ValueError( + f"Invalid cuda_graph_scope '{scope}'. " + f"Valid options are: {valid_scopes}" + ) + model_cfg.cuda_graph_scope = scope + if not model_cfg.enable_cuda_graph: + warnings.warn( + "cuda_graph_scope is configured but enable_cuda_graph is False. " + "The cuda_graph_scope setting will have no effect." 
+ ) + if model_cfg.enable_cuda_graph: + model_cfg.use_te_rng_tracker = True + else: + model_cfg.use_te_rng_tracker = False + # FP8 configuration fp8_cfg = config["megatron_cfg"].get("fp8_cfg", None) if fp8_cfg is not None and fp8_cfg.get("enabled", False): @@ -607,10 +630,16 @@ def _create_megatron_config( dtype: torch.dtype, ) -> ConfigContainer: """Create the final Megatron configuration container.""" + # Create RNG config with CUDA graph support + rng_config = RNGConfig( + te_rng_tracker=config["megatron_cfg"].get("enable_cuda_graph", False), + ) + return ConfigContainer( model=model_cfg, checkpoint=checkpoint_config, logger=LoggerConfig(logging_level=0), + rng=rng_config, train=TrainingConfig( micro_batch_size=1, # ignored global_batch_size=config["train_global_batch_size"], # ignored diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index 3636e5ac64..d4a8107d7a 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Literal, NotRequired, TypedDict, Union +from typing import Any, Literal, NotRequired, TypedDict from nemo_rl.models.generation.interfaces import GenerationConfig +CudaGraphScope = Literal[ + "full_iteration", "attn", "mlp", "moe", "moe_router", "moe_preprocess", "mamba" +] + class LoRAConfigDisabled(TypedDict): enabled: Literal[False] @@ -224,6 +228,8 @@ class MegatronConfig(TypedDict): # Number of tokens per chunk when computing the fused linear CE loss. # Smaller values reduce peak memory further but may decrease throughput. 
linear_ce_fusion_chunk_size: NotRequired[int] + enable_cuda_graph: NotRequired[bool] + cuda_graph_scope: NotRequired[CudaGraphScope] class TokenizerConfig(TypedDict): diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index eff97b215c..db108e7cac 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -65,6 +65,8 @@ def create_megatron_test_config( converter_type: str = "LlamaForCausalLM", logprob_chunk_size: Optional[int] = None, defer_fp32_logits: Optional[bool] = None, + enable_cuda_graph: Optional[bool] = None, + cuda_graph_scope: Optional[str] = None, ) -> PolicyConfig: """Create a test config for Megatron policy worker.""" return { @@ -138,6 +140,8 @@ def create_megatron_test_config( "moe_token_dispatcher_type": "alltoall", "moe_shared_expert_overlap": False, "defer_fp32_logits": defer_fp32_logits, + **({"enable_cuda_graph": enable_cuda_graph} if enable_cuda_graph is not None else {}), + **({"cuda_graph_scope": cuda_graph_scope} if cuda_graph_scope is not None else {}), "use_linear_ce_fusion_loss": False, "linear_ce_fusion_chunk_size": 256, "train_iters": 100, # Required for Megatron training @@ -2686,3 +2690,37 @@ def test_megatron_policy_flops_range_check(tiny_llama_model_path): finally: policy.shutdown() cluster.shutdown() + + +def test_cuda_graph_config_parsing(): + """Test CUDA graph configuration options are properly included in test configs.""" + # Test config without CUDA graph options (default) + config_default = create_megatron_test_config( + model_name="test-model", + ) + assert "enable_cuda_graph" not in config_default["megatron_cfg"] + assert "cuda_graph_scope" not in config_default["megatron_cfg"] + + # Test config with CUDA graph enabled + config_enabled = create_megatron_test_config( + model_name="test-model", + enable_cuda_graph=True, + ) + assert config_enabled["megatron_cfg"]["enable_cuda_graph"] is True + assert "cuda_graph_scope" not in 
config_enabled["megatron_cfg"] + + # Test config with CUDA graph enabled and scope set + config_with_scope = create_megatron_test_config( + model_name="test-model", + enable_cuda_graph=True, + cuda_graph_scope="full_iteration", + ) + assert config_with_scope["megatron_cfg"]["enable_cuda_graph"] is True + assert config_with_scope["megatron_cfg"]["cuda_graph_scope"] == "full_iteration" + + # Test config with CUDA graph disabled + config_disabled = create_megatron_test_config( + model_name="test-model", + enable_cuda_graph=False, + ) + assert config_disabled["megatron_cfg"]["enable_cuda_graph"] is False