diff --git a/nemo_rl/models/megatron/__init__.py b/nemo_rl/models/megatron/__init__.py index e69de29bb2..4fc25d0d3c 100644 --- a/nemo_rl/models/megatron/__init__.py +++ b/nemo_rl/models/megatron/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_rl/models/megatron/config.py b/nemo_rl/models/megatron/config.py new file mode 100644 index 0000000000..7e249affcd --- /dev/null +++ b/nemo_rl/models/megatron/config.py @@ -0,0 +1,76 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Callable, NamedTuple, Optional, TypedDict + +import torch +from megatron.bridge.training.config import ConfigContainer +from megatron.bridge.training.state import GlobalState +from megatron.core.optimizer import MegatronOptimizer +from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler +from megatron.core.transformer import MegatronModule + + +class MegatronGenerationConfig(TypedDict): + # Total GPU memory (in GB) allocated for KV cache buffers + buffer_size_gb: int + # Fraction of buffer reserved for guaranteed active requests + buffer_guaranteed_fraction: float + # Number of CUDA graphs to pre-compile for different batch sizes + num_cuda_graphs: int + # Size of each KV cache block in tokens (affects memory granularity) + block_size_tokens: int + # Enable CUDA graphs for prefill/context processing + use_cuda_graphs_for_non_decode_steps: bool + # Split long prefills into chunks for better memory management + enable_chunked_prefill: bool + # Unified memory usage level (0=disabled, higher values enable more aggressive paging) + unified_memory_level: int + # Maximum number of tokens to use in a single step. Analogous to vllm's max_num_batched_tokens. + # Can cause OOM if set too high so should be tuned with buffer_size_gb if OOMing. If set too + # low, then will only do 512 tokens at a time, which can be slow. + max_tokens: int + + +## returned from validate_and_set_config +class RuntimeConfig(NamedTuple): + """Runtime configuration for model training and inference. + + This contains all validated runtime settings needed for model initialization, + parallelization, and training. 
+ """ + + megatron_cfg: ConfigContainer + model_cfg: Any + dtype: torch.dtype + optimizer_cpu_offload: bool + offload_optimizer_for_logprob: bool + is_generation_colocated: Optional[bool] + final_padded_vocab_size: int + + +## returned from setup_model_and_optimizer +class ModelAndOptimizerState(NamedTuple): + """Container for model and optimizer state. + + This named tuple holds all model-related state including the model itself, + optimizer, scheduler, and metadata about the model type and configuration. + """ + + state: GlobalState + model: MegatronModule + optimizer: MegatronOptimizer + scheduler: OptimizerParamScheduler + checkpointing_context: dict[str, Any] + param_sync_func: Optional[Callable] diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py new file mode 100644 index 0000000000..24bfdb0605 --- /dev/null +++ b/nemo_rl/models/megatron/setup.py @@ -0,0 +1,992 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import time +import warnings +from typing import Any, Optional, TypeVar + +import torch +from megatron.bridge import AutoBridge +from megatron.bridge.models.model_provider import get_model +from megatron.bridge.peft.lora import LoRA +from megatron.bridge.training import fault_tolerance +from megatron.bridge.training.checkpointing import ( + checkpoint_exists, + init_checkpointing_context, + load_checkpoint, +) +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + DistributedDataParallelConfig, + LoggerConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.initialize import ( + initialize_megatron, + set_jit_fusion_options, +) +from megatron.bridge.training.optim import setup_optimizer +from megatron.bridge.training.setup import ( + _create_peft_pre_wrap_hook, + _update_model_config_funcs, +) +from megatron.bridge.training.state import GlobalState +from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer +from megatron.bridge.utils.instantiate_utils import InstantiationMode +from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size +from megatron.core import parallel_state +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.module import Float16Module +from megatron.core.transformer.transformer_config import TransformerConfig +from transformers import PreTrainedTokenizerBase + +try: + from megatron.core.distributed import ( + TorchFullyShardedDataParallel as torch_FSDP, # noqa: F401 unused-import + ) + + HAVE_FSDP2 = True +except ImportError: + HAVE_FSDP2 = False + +from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.models.megatron.community_import import import_model_from_hf_name +from nemo_rl.models.megatron.config import ModelAndOptimizerState, RuntimeConfig +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.utils import ( + 
configure_dynamo_cache, + get_megatron_checkpoint_dir, +) + +TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) + + +def destroy_parallel_state(): + """Safely destroy parallel state and reset async call tracking. + + This function is called during initialization to clean up temporary distributed + state from model import operations. Resetting async call tracking ensures that + when the main Megatron distributed context is created, all ranks start with + consistent call_idx values for async checkpointing. + """ + if torch.distributed.is_initialized(): + try: + torch.distributed.barrier() + torch.distributed.destroy_process_group() + except: + pass # Ignore errors if already destroyed + if hasattr(parallel_state, "destroy_model_parallel"): + try: + parallel_state.destroy_model_parallel() + except: + pass # Ignore errors if already destroyed + + # Reset async calls queue to prevent call_idx mismatches after distributed context recreation + try: + import nemo.tron.utils.async_utils as nemo_async_utils + from megatron.core.dist_checkpointing.strategies.async_utils import ( + AsyncCallsQueue, + ) + + # Clean up any existing async callers first + old_call_idx = getattr(nemo_async_utils._async_calls_queue, "call_idx", None) + num_unfinalized = ( + nemo_async_utils._async_calls_queue.get_num_unfinalized_calls() + ) + if num_unfinalized > 0: + print( + f"[WARNING] Resetting async calls queue with {num_unfinalized} unfinalized calls" + ) + try: + nemo_async_utils._async_calls_queue.close() + except: + pass # Ignore errors during cleanup + # Reset the global async calls queue by creating a new instance + nemo_async_utils._async_calls_queue = AsyncCallsQueue() + except ImportError: + pass + + # Also reset the Megatron async calls queue if it exists + try: + import megatron.training.async_utils as megatron_async_utils + from megatron.core.dist_checkpointing.strategies.async_utils import ( + AsyncCallsQueue, + ) + + # Clean up any existing async callers first 
+ old_call_idx = getattr( + megatron_async_utils._async_calls_queue, "call_idx", None + ) + num_unfinalized = ( + megatron_async_utils._async_calls_queue.get_num_unfinalized_calls() + ) + if num_unfinalized > 0: + print( + f"[WARNING] Resetting Megatron async calls queue with {num_unfinalized} unfinalized calls" + ) + try: + megatron_async_utils._async_calls_queue.close() + except: + pass # Ignore errors during cleanup + # Reset the Megatron global async calls queue as well + megatron_async_utils._async_calls_queue = AsyncCallsQueue() + print( + f"[DEBUG] Reset Megatron async calls queue (old call_idx: {old_call_idx})" + ) + except ImportError: + pass + + # Reset the third global async_calls instance in base strategy module + try: + import megatron.core.dist_checkpointing.strategies.base as base_strategy + from megatron.core.dist_checkpointing.strategies.async_utils import ( + AsyncCallsQueue, + ) + + # Clean up and reset the global async_calls in base strategy + old_call_idx = getattr(base_strategy.async_calls, "call_idx", None) + num_unfinalized = base_strategy.async_calls.get_num_unfinalized_calls() + if num_unfinalized > 0: + print( + f"[WARNING] Resetting base strategy async_calls with {num_unfinalized} unfinalized calls" + ) + try: + base_strategy.async_calls.close() + except: + pass + base_strategy.async_calls = AsyncCallsQueue() + print(f"[DEBUG] Reset base strategy async_calls (old call_idx: {old_call_idx})") + except ImportError: + pass + + +def setup_distributed() -> None: + """Handle NCCL settings, dtype mapping, and basic config setup.""" + # Disable dynamo autotune_local_cache to avoid crash when there's already a cache + # with different order of node_bundles + configure_dynamo_cache() + # Ensure clean slate before import + destroy_parallel_state() + # Need to initialize the process group before calling into Megatron-Bridge, otherwise Megatron-Bridge will try to set an incorrect device + torch.distributed.init_process_group("nccl") + + +def 
validate_and_set_config( + config, + rank, + hf_model_name, + pretrained_path, + weights_path, + tokenizer, +): + # Handle generation colocation + is_generation_colocated = None + if "generation" in config and config["generation"] is not None: + is_generation_colocated = config["generation"]["colocated"]["enabled"] + + # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. + # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. + if not is_generation_colocated: + os.environ["NCCL_CUMEM_ENABLE"] = "1" + + # Setup data types + dtype_map = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + } + dtype = dtype_map[config["precision"]] + + # Optimizer configuration + optimizer_cpu_offload = config["megatron_cfg"]["optimizer"]["optimizer_cpu_offload"] + offload_optimizer_for_logprob = config["offload_optimizer_for_logprob"] + + # Reward models are not yet supported with Megatron. + if "reward_model_cfg" in config and config["reward_model_cfg"]["enabled"]: + raise NotImplementedError( + "Reward models are not yet supported with the Megatron backend, this issue is " + "tracked in https://github.com/NVIDIA-NeMo/RL/issues/720" + ) + + megatron_cfg, model_cfg = setup_model_config( + config, rank, dtype, hf_model_name, pretrained_path, weights_path + ) + + final_padded_vocab_size = calculate_padded_vocab_size( + megatron_cfg.model.vocab_size, + megatron_cfg.model.make_vocab_size_divisible_by, + config["megatron_cfg"]["tensor_model_parallel_size"], + ) + + return RuntimeConfig( + megatron_cfg, + model_cfg, + dtype, + optimizer_cpu_offload, + offload_optimizer_for_logprob, + is_generation_colocated, + final_padded_vocab_size, + ) + + +def validate_model_paths(config: PolicyConfig) -> tuple[str, str, bool]: + """Validate and setup model paths.""" + # cfg["model_name"] is allowed to be either an HF model name or a path to an HF checkpoint + hf_model_name = config["model_name"] + + # 
Check if the checkpoint already exists + hf_model_subdir = hf_model_name + if os.path.exists(hf_model_name): + hf_model_subdir = f"model_{hf_model_subdir.replace('/', '_')}" + + pretrained_path = f"{get_megatron_checkpoint_dir()}/{hf_model_subdir}" + pt_checkpoint_exists = os.path.exists(pretrained_path) and os.path.exists( + os.path.join(pretrained_path, "iter_0000000") + ) + + return hf_model_name, pretrained_path, pt_checkpoint_exists + + +def setup_model_config( + config: PolicyConfig, + rank, + dtype, + hf_model_name: str, + pretrained_path: str, + weights_path: Optional[str] = None, +) -> tuple[ConfigContainer, Any]: + """Handle all the model configuration logic.""" + # Load pretrained run config + pretrained_run_config = os.path.join( + pretrained_path, "iter_0000000/run_config.yaml" + ) + + if not os.path.exists(pretrained_run_config): + raise FileNotFoundError( + f"Pretrained run config not found at {pretrained_run_config} on rank={rank}. " + "This usually means that the one-time HF->mcore conversion on rank=0 saved to a directory " + "not being mounted on this node. 
Please check" + ) + + try: + cfg_from_pretrained = ConfigContainer.from_yaml( + pretrained_run_config, mode=InstantiationMode.STRICT + ) + except Exception as e: + # Add helpful context as a note to the exception + e.add_note( + f"\n{'=' * 80}\n" + f"NOTE: A common cause of this error is when the HF->mcore converted checkpoint is\n" + f"created with an older version of megatron-bridge.\n" + f"If this checkpoint is old or was generated by a different code version,\n" + f"try deleting it and rerunning the code.\n" + f"The checkpoint will be automatically regenerated with the current version.\n\n" + f"Checkpoint location: {pretrained_path}\n" + f"{'=' * 80}" + ) + raise + + model_cfg = cfg_from_pretrained.model + cfg_from_pretrained.logger = LoggerConfig() + + # Apply parallelism settings + _apply_parallelism_config(model_cfg, config) + + # Apply MoE settings + _apply_moe_config(model_cfg, config) + + # Apply precision settings + _apply_precision_config(model_cfg, config, dtype) + + # Apply performance settings + _apply_performance_config(model_cfg, config) + + # Validate optimizer configuration + _validate_optimizer_config(config) + + # Optional layernorm epsilon + if "layernorm_epsilon" in config["megatron_cfg"]: + model_cfg.layernorm_epsilon = config["megatron_cfg"]["layernorm_epsilon"] + + # Validate chunking configuration + _validate_chunking_config(config) + + # Create checkpoint configs + checkpoint_config = _create_checkpoint_config(pretrained_path, weights_path) + + # Validate training configuration + _validate_training_config(config, model_cfg) + + # Create final megatron config + megatron_cfg = _create_megatron_config( + model_cfg, checkpoint_config, config, hf_model_name, dtype + ) + + _validate_dtype_config(dtype, megatron_cfg.model, megatron_cfg.optimizer) + + return megatron_cfg, model_cfg + + +def _apply_parallelism_config(model_cfg: Any, config: PolicyConfig) -> None: + """Apply tensor/pipeline/context parallelism configuration.""" + 
model_cfg.tensor_model_parallel_size = config["megatron_cfg"][ + "tensor_model_parallel_size" + ] + model_cfg.pipeline_model_parallel_size = config["megatron_cfg"][ + "pipeline_model_parallel_size" + ] + model_cfg.num_layers_in_first_pipeline_stage = config["megatron_cfg"][ + "num_layers_in_first_pipeline_stage" + ] + model_cfg.num_layers_in_last_pipeline_stage = config["megatron_cfg"][ + "num_layers_in_last_pipeline_stage" + ] + model_cfg.sequence_parallel = config["megatron_cfg"]["sequence_parallel"] + model_cfg.context_parallel_size = config["megatron_cfg"]["context_parallel_size"] + + if model_cfg.context_parallel_size > 1: + assert config["sequence_packing"]["enabled"], ( + "Sequence Packing must be enabled to use Context Parallelism with MCore" + ) + + +def _apply_moe_config(model_cfg: Any, config: PolicyConfig) -> None: + """Apply Mixture of Experts configuration.""" + model_cfg.expert_tensor_parallel_size = config["megatron_cfg"][ + "expert_tensor_parallel_size" + ] + model_cfg.expert_model_parallel_size = config["megatron_cfg"][ + "expert_model_parallel_size" + ] + + # MoE stability settings + + # Setting moe_router_dtype to higher precision (e.g. fp64) can improve numerical stability, + # especially when using many experts. + model_cfg.moe_router_dtype = config["megatron_cfg"]["moe_router_dtype"] + + # The below two configs (and "freeze_moe_router") are used to stabilize moe training + # by preventing updates to the moe router. We found that this is helpful in reducing + # logprob error during training. + + # Set this to "none" to disable load balancing loss. 
+ model_cfg.moe_router_load_balancing_type = config["megatron_cfg"][ + "moe_router_load_balancing_type" + ] + # Set this to 0.0 to disable updates to the moe router expert bias + model_cfg.moe_router_bias_update_rate = config["megatron_cfg"][ + "moe_router_bias_update_rate" + ] + + model_cfg.moe_enable_deepep = config["megatron_cfg"]["moe_enable_deepep"] + model_cfg.moe_token_dispatcher_type = config["megatron_cfg"][ + "moe_token_dispatcher_type" + ] + model_cfg.moe_shared_expert_overlap = config["megatron_cfg"][ + "moe_shared_expert_overlap" + ] + + model_cfg.moe_permute_fusion = config["megatron_cfg"]["moe_permute_fusion"] + + +def _apply_precision_config( + model_cfg: Any, config: PolicyConfig, dtype: torch.dtype +) -> None: + """Apply precision and dtype configuration.""" + model_cfg.bf16 = dtype == torch.bfloat16 + model_cfg.fp16 = dtype == torch.float16 + + if model_cfg.fp16: + assert not model_cfg.bf16, "fp16 and bf16 cannot be used together" + model_cfg.params_dtype = torch.float16 + elif model_cfg.bf16: + assert not model_cfg.fp16, "fp16 and bf16 cannot be used together" + model_cfg.params_dtype = torch.bfloat16 + else: + model_cfg.params_dtype = torch.float32 + + dtype_map = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + } + model_cfg.pipeline_dtype = dtype_map[config["megatron_cfg"]["pipeline_dtype"]] + + +def _apply_performance_config(model_cfg: Any, config: PolicyConfig) -> None: + """Apply performance optimization configuration.""" + model_cfg.parallel_output = True + + # Activation checkpointing + if config["megatron_cfg"]["activation_checkpointing"]: + model_cfg.recompute_granularity = "full" + model_cfg.recompute_method = "uniform" + model_cfg.recompute_num_layers = 1 + + # Activation function validation + if not model_cfg.gated_linear_unit: + assert model_cfg.activation_func is not None, ( + "activation_func must be set if not using gated_linear_unit. 
This likely " + "indicates an issue in configuration conversion (e.g. activation func was " + "a lambda and couldn't be serialized). This is based on this check " + "https://github.com/NVIDIA/Megatron-LM/blob/1ab876ddc4c1893c76f26d775226a8d1dcdfb3d2/megatron/core/transformer/mlp.py#L174." + ) + + # Fusion settings + model_cfg.apply_rope_fusion = config["megatron_cfg"]["apply_rope_fusion"] + model_cfg.bias_activation_fusion = config["megatron_cfg"]["bias_activation_fusion"] + + # FP8 configuration + fp8_cfg = config["megatron_cfg"].get("fp8_cfg", None) + if fp8_cfg is not None and fp8_cfg.get("enabled", False): + try: + model_cfg.fp8 = fp8_cfg["fp8"] + model_cfg.fp8_recipe = fp8_cfg["fp8_recipe"] + model_cfg.fp8_param = fp8_cfg["fp8_param"] + except KeyError as e: + raise KeyError(f"Missing key in fp8_cfg: {e}") + + if model_cfg.fp8_param: + warnings.warn( + "Setting fp8_param=True sometimes causes NaN token_mult_prob_error, please use with caution. " + "Refer to https://github.com/NVIDIA-NeMo/RL/issues/1164 for latest updates with this issue." + ) + + +def _validate_optimizer_config(config: PolicyConfig) -> None: + """Validate optimizer configuration.""" + optimizer_cpu_offload = config["megatron_cfg"]["optimizer"]["optimizer_cpu_offload"] + optimizer_offload_fraction = config["megatron_cfg"]["optimizer"][ + "optimizer_offload_fraction" + ] + + if optimizer_cpu_offload: + # Currently, hybrid optimizer (partly on GPU and partly on CPU) is not supported because it conflicts with the way + # Nemo-rl handles the optimizer offload/onload between generation and training. So if using CPU optimizer the offload_fraction should be 1.0. 
+ assert optimizer_offload_fraction == 1.0, ( + "Currently for optimizer offloading, only optimizer_offload_fraction=1.0 is supported" + ) + + +def _validate_chunking_config(config: PolicyConfig) -> None: + """Validate chunking configuration.""" + if ( + "logprob_chunk_size" in config + and config["logprob_chunk_size"] is not None + and config["logprob_chunk_size"] > 0 + ): + assert config["megatron_cfg"]["defer_fp32_logits"], ( + "defer_fp32_logits must be True if logprob_chunk_size is set" + ) + + +def _create_checkpoint_config( + pretrained_path: str, weights_path: Optional[str] +) -> CheckpointConfig: + """Create checkpoint configurations.""" + return CheckpointConfig( + save_interval=100, + save=weights_path, + load=weights_path, + pretrained_checkpoint=pretrained_path, + async_save=False, + fully_parallel_save=True, + fully_parallel_load=True, + load_rng=False, + ) + + +def _validate_training_config(config: PolicyConfig, model_cfg: Any) -> None: + """Validate training configuration.""" + assert "train_iters" in config["megatron_cfg"], ( + "train_iters must be set in megatron_cfg. For an example, see " + "https://github.com/NVIDIA-NeMo/RL/blob/bccbc377705a81a1f4b3c31ad9767bcc15f735a8/nemo_rl/algorithms/sft.py#L175-L179." + ) + + ## These settings are required for correct gradient computations in mcore + ## when calculate_per_token_loss is True, there is no scaling of the gradient in mcore, + ## so we handle the scaling in nemo-rl. + ## perform_initialization = True is a workaround to ensure the correct tensor parallel attributes are set + ## on the TP-sharded parameters. + model_cfg.calculate_per_token_loss = True + model_cfg.perform_initialization = True + + # MoE aux loss validation + assert ( + "aux_loss" not in model_cfg.moe_router_load_balancing_type + or model_cfg.moe_aux_loss_coeff == 0 + ), ( + "MoE aux loss is currently not supported due to a known bug in Megatron-LM. " + "See https://github.com/NVIDIA/Megatron-LM/issues/1984 for more details." 
+ ) + + +def _validate_dtype_config( + dtype: torch.dtype, model_cfg: Any, optimizer_cfg: Any +) -> None: + # TODO: this validation should happen inside mbridge: https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/1665 + if dtype == torch.bfloat16: + assert model_cfg.bf16 == True, ( + "policy.megatron_cfg.model.bf16=True must be set if policy.precision=bfloat16. This is handled by nemo-rl so this indicates something is misconfigured." + ) + assert ( + optimizer_cfg.use_precision_aware_optimizer == False + or optimizer_cfg.bf16 == True + ), ( + "policy.megatron_cfg.optimizer.bf16=True must be set if policy.precision=bfloat16 when using use_precision_aware_optimizer=True" + ) + elif dtype == torch.float16: + assert model_cfg.fp16 == True, ( + "policy.megatron_cfg.model.fp16=True must be set if policy.precision=float16. This is handled by nemo-rl so this indicates something is misconfigured." + ) + assert ( + optimizer_cfg.use_precision_aware_optimizer == False + or optimizer_cfg.fp16 == True + ), ( + "policy.megatron_cfg.optimizer.fp16=True must be set if policy.precision=float16 when using use_precision_aware_optimizer=True" + ) + elif dtype == torch.float32: + assert model_cfg.bf16 == False and model_cfg.fp16 == False, ( + "policy.megatron_cfg.model.bf16=False and policy.megatron_cfg.model.fp16=False must be set if policy.precision=float32. This is handled by nemo-rl so this indicates something is misconfigured." 
+ ) + assert optimizer_cfg.bf16 == False and optimizer_cfg.fp16 == False, ( + "policy.megatron_cfg.optimizer.bf16=False and policy.megatron_cfg.optimizer.fp16=False must be set if policy.precision=float32" + ) + + +def _create_megatron_config( + model_cfg: Any, + checkpoint_config: CheckpointConfig, + config: PolicyConfig, + hf_model_name: str, + dtype: torch.dtype, +) -> ConfigContainer: + """Create the final Megatron configuration container.""" + return ConfigContainer( + model=model_cfg, + checkpoint=checkpoint_config, + logger=LoggerConfig(logging_level=0), + train=TrainingConfig( + micro_batch_size=1, # ignored + global_batch_size=config["train_global_batch_size"], # ignored + train_iters=config["megatron_cfg"]["train_iters"], + ), + optimizer=OptimizerConfig(**config["megatron_cfg"]["optimizer"]), + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=config["megatron_cfg"][ + "distributed_data_parallel_config" + ]["grad_reduce_in_fp32"], + overlap_grad_reduce=config["megatron_cfg"][ + "distributed_data_parallel_config" + ]["overlap_grad_reduce"], + overlap_param_gather=config["megatron_cfg"][ + "distributed_data_parallel_config" + ]["overlap_param_gather"], + # we need to set average_in_collective=False with calculate_per_token_loss=T + # otherwise, mcore throws an assertion error. 
+ average_in_collective=False, # Required with calculate_per_token_loss=True + use_distributed_optimizer=config["megatron_cfg"]["optimizer"][ + "use_distributed_optimizer" + ], + data_parallel_sharding_strategy=config["megatron_cfg"][ + "distributed_data_parallel_config" + ]["data_parallel_sharding_strategy"], + ), + scheduler=SchedulerConfig(**config["megatron_cfg"]["scheduler"]), + dataset=None, + tokenizer=TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model=hf_model_name, + ), + ) + + +def setup_model_and_optimizer( + policy_cfg: PolicyConfig, + megatron_cfg: ConfigContainer, + load_optimizer: bool = True, + get_embedding_ranks=None, # TODO @sahilj: What is this? + get_position_embedding_ranks=None, +): + state = GlobalState() + state.cfg = megatron_cfg + # TODO: Freeze state.cfg + + megatron_cfg.dist.external_gpu_device_mapping = True + initialize_megatron( + cfg=megatron_cfg, + get_embedding_ranks=get_embedding_ranks, + get_position_embedding_ranks=get_position_embedding_ranks, + ) + + if megatron_cfg.ft and megatron_cfg.ft.enable_ft_package: + fault_tolerance.setup(megatron_cfg, state) + fault_tolerance.maybe_setup_simulated_fault(megatron_cfg.ft) + + # Set pytorch JIT layer fusion options and warmup JIT functions. + set_jit_fusion_options(megatron_cfg.model, megatron_cfg.train.micro_batch_size) + + # Adjust the startup time so it reflects the largest value. + # This will be closer to what scheduler will see (outside of + # image ... launches. + start_time_tensor = torch.tensor( + [state.start_time], dtype=torch.double, device="cuda" + ) + torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN) + state.start_time = start_time_tensor.item() + + print( + "time to initialize megatron (seconds): {:.3f}".format( + time.time() - state.start_time + ) + ) + torch.distributed.barrier() + + # Context used for persisting some state between checkpoint saves. 
+ checkpointing_context = init_checkpointing_context(megatron_cfg.checkpoint) + + # Tokenizer + build_tokenizer( + megatron_cfg.tokenizer, + make_vocab_size_divisible_by=megatron_cfg.model.make_vocab_size_divisible_by + // megatron_cfg.model.tensor_model_parallel_size, + tensor_model_parallel_size=megatron_cfg.model.tensor_model_parallel_size, + trust_remote_code=True, + ) + assert megatron_cfg.model.vocab_size, "vocab size must be specified in model config" + + torch.distributed.barrier() + + pre_wrap_hook = [] + + use_peft = policy_cfg["megatron_cfg"].get("peft", {}).get("enabled", False) + + mixed_precision_wrapper = Float16Module + if policy_cfg["megatron_cfg"]["freeze_moe_router"]: + if use_peft: + raise ValueError( + "Freezing the MOE router is not currently supported when using PEFT" + ) + + def freeze_moe_router(megatron_model): + if not isinstance(megatron_model, list): + megatron_model = [megatron_model] + for model_module in megatron_model: + # Handle both wrapped (Float16Module) and unwrapped models + if isinstance(model_module, Float16Module): + model_module = model_module.module + # Handle VLM models + if hasattr(model_module, "language_model"): + model_module = model_module.language_model + for layer in model_module.decoder.layers: + if hasattr(layer, "mlp") and hasattr(layer.mlp, "router"): + layer.mlp.router.weight.requires_grad = False + + mixed_precision_wrapper = MoEFloat16Module + pre_wrap_hook.extend([freeze_moe_router]) + + if use_peft: + peft_cfg = policy_cfg["megatron_cfg"].get("peft", {}) + peft = LoRA( + target_modules=peft_cfg["target_modules"], + exclude_modules=peft_cfg["exclude_modules"], + dim=peft_cfg["dim"], + alpha=peft_cfg["alpha"], + dropout=peft_cfg["dropout"], + dropout_position=peft_cfg["dropout_position"], + lora_A_init_method=peft_cfg["lora_A_init_method"], + lora_B_init_method=peft_cfg["lora_B_init_method"], + a2a_experimental=peft_cfg["a2a_experimental"], + lora_dtype=peft_cfg["lora_dtype"], + ) + else: + peft = None + 
megatron_cfg.peft = peft + + if megatron_cfg.peft is not None: + pre_peft_hook = _create_peft_pre_wrap_hook(megatron_cfg, state) + megatron_cfg.model.register_pre_wrap_hook(pre_peft_hook) + + def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: + model = pre_peft_hook(model) + return model + + pre_wrap_hook.extend([composed_peft_hook]) + + # Model, optimizer, and learning rate. + model = get_model( + megatron_cfg.model, + megatron_cfg.ddp, + use_torch_fsdp2=megatron_cfg.dist.use_torch_fsdp2, + overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step, + data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init, + pre_wrap_hook=pre_wrap_hook, + mixed_precision_wrapper=mixed_precision_wrapper, + ) + if load_optimizer: + optimizer, scheduler = setup_optimizer( + optimizer_config=megatron_cfg.optimizer, + scheduler_config=megatron_cfg.scheduler, + model=model, + use_gloo_process_groups=megatron_cfg.dist.use_gloo_process_groups, + ) + else: + optimizer = None + scheduler = None + + print("Model, optimizer, and learning rate scheduler built") + torch.distributed.barrier() + + if megatron_cfg.peft is not None: + should_load_checkpoint = ( + megatron_cfg.checkpoint.load is not None + and checkpoint_exists(megatron_cfg.checkpoint.load) + ) + if should_load_checkpoint: + # The finetune toggle is explicitly set to True in order to avoid loading optimizer and RNG states + # This is switched off here in order to load these states from the checkpoint + megatron_cfg.checkpoint.finetune = False + else: + should_load_checkpoint = ( + megatron_cfg.checkpoint.load is not None + and checkpoint_exists(megatron_cfg.checkpoint.load) + ) or ( + megatron_cfg.checkpoint.pretrained_checkpoint is not None + and checkpoint_exists(megatron_cfg.checkpoint.pretrained_checkpoint) + ) + + # Load checkpoint if applicable + if should_load_checkpoint: + load_checkpoint( + state, + model, + optimizer, + scheduler, + 
checkpointing_context=checkpointing_context, + skip_load_to_model_and_opt=HAVE_FSDP2 and megatron_cfg.dist.use_torch_fsdp2, + ) + print("Checkpoint loaded") + torch.distributed.barrier() + + # Set the param sync function for the model + param_sync_func = None + if megatron_cfg.ddp.overlap_param_gather and megatron_cfg.ddp.align_param_gather: + param_sync_func = [model_chunk.start_param_sync for model_chunk in model] + if len(model) == 1: + param_sync_func = param_sync_func[0] + + # Get the first model from the list + model = model[0] + + return ModelAndOptimizerState( + state, + model, + optimizer, + scheduler, + checkpointing_context, + param_sync_func, + ) + + +def handle_model_import( + config: PolicyConfig, + hf_model_name: str, + pretrained_path: str, + pt_checkpoint_exists: bool, +) -> None: + """Handle HF model import if checkpoint doesn't exist.""" + if pt_checkpoint_exists: + print(f"Checkpoint already exists at {pretrained_path}. Skipping import.") + else: + hf_config_overrides = config.get("hf_config_overrides", {}) or {} + import_model_from_hf_name( + hf_model_name, + pretrained_path, + config["megatron_cfg"], + **hf_config_overrides, + ) + + if parallel_state.model_parallel_is_initialized(): + print("Reinitializing model parallel after loading model state.") + parallel_state.destroy_model_parallel() + + +def setup_reference_model_state( + config: PolicyConfig, megatron_cfg: ConfigContainer, pretrained_path: str +) -> dict: + """Setup the reference model for inference and return its state dict.""" + # Create reference checkpoint config + ref_checkpoint_config = CheckpointConfig( + pretrained_checkpoint=pretrained_path, + save=None, + load=None, + fully_parallel_load=True, + load_rng=False, + ) + + ref_ckpt_context = init_checkpointing_context(ref_checkpoint_config) + + # Create a separate megatron config for the reference model + ref_megatron_cfg = ConfigContainer( + model=megatron_cfg.model, + checkpoint=ref_checkpoint_config, + 
logger=megatron_cfg.logger,
+        train=megatron_cfg.train,
+        optimizer=megatron_cfg.optimizer,
+        ddp=megatron_cfg.ddp,
+        scheduler=megatron_cfg.scheduler,
+        dataset=megatron_cfg.dataset,
+        tokenizer=megatron_cfg.tokenizer,
+    )
+
+    # Create a separate state object for the reference model
+    ref_state = GlobalState()
+    ref_state.cfg = ref_megatron_cfg
+
+    # Configure mixed precision wrapper for reference model
+    ref_mixed_precision_wrapper = Float16Module
+    if config["megatron_cfg"].get("freeze_moe_router", False):
+        ref_mixed_precision_wrapper = MoEFloat16Module
+
+    reference_model = get_model(
+        megatron_cfg.model,
+        megatron_cfg.ddp,
+        use_torch_fsdp2=megatron_cfg.dist.use_torch_fsdp2,
+        overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step,
+        data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init,
+        mixed_precision_wrapper=ref_mixed_precision_wrapper,
+    )
+
+    print("Loading the Reference Model")
+    reference_state_dict = {}
+
+    if ref_checkpoint_config.pretrained_checkpoint is not None and checkpoint_exists(
+        ref_checkpoint_config.pretrained_checkpoint
+    ):
+        load_checkpoint(
+            ref_state,
+            reference_model,
+            None,  # no optimizer
+            None,  # no scheduler
+            checkpointing_context=ref_ckpt_context,
+            skip_load_to_model_and_opt=HAVE_FSDP2 and megatron_cfg.dist.use_torch_fsdp2,
+        )
+        reference_model = reference_model[0]
+        reference_model.eval()
+
+        # Store reference state dict on CPU
+        for name, item in reference_model.state_dict().items():
+            if isinstance(item, torch.Tensor):
+                cpu_item = item.detach().to(device="cpu", non_blocking=True, copy=True)
+                del item
+            else:
+                cpu_item = item
+            reference_state_dict[name] = cpu_item
+        print("Reference model loaded")
+    else:
+        print("Reference model not loaded")
+
+    return reference_state_dict
+
+
+def finalize_megatron_setup(
+    config: PolicyConfig,
+    megatron_cfg: ConfigContainer,
+    hf_model_name: str,
+    worker_sharding_annotations: NamedSharding,
+    model,
+    optimizer,
+) -> 
tuple:
+    """Finalize the setup with remaining configurations.
+
+    Returns:
+        Tuple of (megatron_tokenizer, megatron_bridge, should_disable_forward_pre_hook, dp_size)
+    """
+    _update_model_config_funcs(
+        [model],
+        megatron_cfg.model,
+        megatron_cfg.ddp,
+        optimizer,
+        align_grad_reduce=megatron_cfg.dist.align_grad_reduce,
+    )
+
+    tokenizer_config = TokenizerConfig(
+        tokenizer_type="HuggingFaceTokenizer",
+        tokenizer_model=hf_model_name,
+    )
+
+    megatron_tokenizer = build_tokenizer(
+        tokenizer_config,
+        make_vocab_size_divisible_by=megatron_cfg.model.make_vocab_size_divisible_by
+        // config["megatron_cfg"]["tensor_model_parallel_size"],
+        tensor_model_parallel_size=config["megatron_cfg"]["tensor_model_parallel_size"],
+        trust_remote_code=True,
+    )
+
+    dp_size = worker_sharding_annotations.get_axis_size("data_parallel")
+    megatron_bridge = AutoBridge.from_hf_pretrained(
+        hf_model_name, trust_remote_code=True
+    )
+
+    should_disable_forward_pre_hook = (
+        config["megatron_cfg"]["optimizer"]["use_distributed_optimizer"]
+        and config["megatron_cfg"]["distributed_data_parallel_config"][
+            "overlap_param_gather"
+        ]
+    )
+
+    return megatron_tokenizer, megatron_bridge, should_disable_forward_pre_hook, dp_size
+
+
+class MoEFloat16Module(Float16Module):
+    """Float 16 Module with the ability to keep the expert bias in float32.
+
+    Attributes:
+        config (TransformerConfig): Transformer config
+        fp16 (bool) : Specifies if the model runs in fp16 mode
+        bf16 (bool) : Specifies if the model runs in bf16 mode
+
+    Args:
+        config (TransformerConfig): The transformer config used to initialize the model
+    """
+
+    def __init__(self, config: TransformerConfig, module: torch.nn.Module):
+        super(MoEFloat16Module, self).__init__(config, module)
+        self.re_enable_float32_expert_bias()
+
+    def re_enable_float32_expert_bias(self) -> None:
+        """Ensure MoE router expert bias stays in float32 for numerical stability. 
+ + Walks the wrapped module to find MoE routers and invokes the + `_maintain_float32_expert_bias()` helper which recreates or casts the + expert bias tensors to float32 as required by Megatron-LM. + """ + module = self.module + # Handle VLM models where language model is nested + if hasattr(module, "language_model"): + module = module.language_model + if hasattr(module, "decoder") and hasattr(module.decoder, "layers"): + for layer in module.decoder.layers: + mlp = getattr(layer, "mlp", None) + router = getattr(mlp, "router", None) if mlp is not None else None + if router is not None and hasattr( + router, "_maintain_float32_expert_bias" + ): + router._maintain_float32_expert_bias() diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 63b6691f13..bedbfb739a 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -19,49 +19,19 @@ from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial -from typing import Any, Iterator, Optional, TypedDict, TypeVar, cast +from typing import Any, Iterator, Optional, TypeVar, cast import ray import torch -from megatron.bridge import AutoBridge -from megatron.bridge.models.model_provider import get_model -from megatron.bridge.peft.lora import LoRA -from megatron.bridge.training import fault_tolerance from megatron.bridge.training.checkpointing import ( - checkpoint_exists, - init_checkpointing_context, - load_checkpoint, maybe_finalize_async_save, save_checkpoint, ) -from megatron.bridge.training.config import ( - CheckpointConfig, - ConfigContainer, - DistributedDataParallelConfig, - LoggerConfig, - OptimizerConfig, - SchedulerConfig, - TokenizerConfig, - TrainingConfig, -) -from megatron.bridge.training.initialize import ( - initialize_megatron, - set_jit_fusion_options, -) -from 
megatron.bridge.training.optim import setup_optimizer -from megatron.bridge.training.setup import ( - _create_peft_pre_wrap_hook, - _update_model_config_funcs, -) -from megatron.bridge.training.state import GlobalState -from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer from megatron.bridge.training.utils.train_utils import ( logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group, ) from megatron.bridge.utils.common_utils import get_rank_safe -from megatron.bridge.utils.instantiate_utils import InstantiationMode -from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallel from megatron.core.distributed.fsdp.mcore_fsdp_adapter import ( @@ -87,11 +57,7 @@ ) from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import get_rerun_state_machine -from megatron.core.transformer import MegatronModule -from megatron.core.transformer.module import Float16Module -from megatron.core.transformer.transformer_config import TransformerConfig from megatron.training.utils import get_ltor_masks_and_position_ids -from ray.util.queue import Queue from transformers import PreTrainedTokenizerBase from nemo_rl.algorithms.interfaces import LossFunction, LossType @@ -116,55 +82,30 @@ forward_step_arbitrary_loss, get_moe_metrics, ) -from nemo_rl.models.megatron.community_import import import_model_from_hf_name +from nemo_rl.models.megatron.config import MegatronGenerationConfig +from nemo_rl.models.megatron.setup import ( + finalize_megatron_setup, + handle_model_import, + setup_distributed, + setup_model_and_optimizer, + setup_reference_model_state, + validate_and_set_config, + validate_model_paths, +) from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( ColocatablePolicyInterface, LogprobOutputSpec, ) -from 
nemo_rl.models.policy.utils import ( - configure_dynamo_cache, - get_megatron_checkpoint_dir, - get_runtime_env_for_policy_worker, -) +from nemo_rl.models.policy.utils import get_runtime_env_for_policy_worker from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker from nemo_rl.models.policy.workers.patches import apply_transformer_engine_patch from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_producer -try: - from megatron.core.distributed import ( - TorchFullyShardedDataParallel as torch_FSDP, # noqa: F401 unused-import - ) - - HAVE_FSDP2 = True -except ImportError: - HAVE_FSDP2 = False - TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) -class MegatronGenerationConfig(TypedDict): - # Total GPU memory (in GB) allocated for KV cache buffers - buffer_size_gb: int - # Fraction of buffer reserved for guaranteed active requests - buffer_guaranteed_fraction: float - # Number of CUDA graphs to pre-compile for different batch sizes - num_cuda_graphs: int - # Size of each KV cache block in tokens (affects memory granularity) - block_size_tokens: int - # Enable CUDA graphs for prefill/context processing - use_cuda_graphs_for_non_decode_steps: bool - # Split long prefills into chunks for better memory management - enable_chunked_prefill: bool - # Unified memory usage level (0=disabled, higher values enable more aggressive paging) - unified_memory_level: int - # Maximum number of tokens to use in a single step. Analogous to vllm's max_num_batched_tokens. - # Can cause OOM if set too high so should be tuned with buffer_size_gb if OOMing. If set too - # low, then will only do 512 tokens at a time, which can be slow. - max_tokens: int - - def broadcast_object_across_pp_ranks(obj): """Broadcast an object across pipeline parallel ranks. 
@@ -218,273 +159,6 @@ def broadcast_object_across_pp_ranks(obj): return obj_list[0] -def setup_megatron_model( - policy_cfg: PolicyConfig, - cfg: ConfigContainer, - load_optimizer: bool = True, - get_embedding_ranks=None, # TODO @sahilj: What is this? - get_position_embedding_ranks=None, -): - state = GlobalState() - state.cfg = cfg - # TODO: Freeze state.cfg - - cfg.dist.external_gpu_device_mapping = True - initialize_megatron( - cfg=cfg, - get_embedding_ranks=get_embedding_ranks, - get_position_embedding_ranks=get_position_embedding_ranks, - ) - - if cfg.ft and cfg.ft.enable_ft_package: - fault_tolerance.setup(cfg, state) - fault_tolerance.maybe_setup_simulated_fault(cfg.ft) - - # Set pytorch JIT layer fusion options and warmup JIT functions. - set_jit_fusion_options(cfg.model, cfg.train.micro_batch_size) - - # Adjust the startup time so it reflects the largest value. - # This will be closer to what scheduler will see (outside of - # image ... launches. - start_time_tensor = torch.tensor( - [state.start_time], dtype=torch.double, device="cuda" - ) - torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN) - state.start_time = start_time_tensor.item() - - print( - "time to initialize megatron (seconds): {:.3f}".format( - time.time() - state.start_time - ) - ) - torch.distributed.barrier() - - # Context used for persisting some state between checkpoint saves. 
- checkpointing_context = init_checkpointing_context(cfg.checkpoint) - - # Tokenizer - build_tokenizer( - cfg.tokenizer, - make_vocab_size_divisible_by=cfg.model.make_vocab_size_divisible_by - // cfg.model.tensor_model_parallel_size, - tensor_model_parallel_size=cfg.model.tensor_model_parallel_size, - trust_remote_code=True, - ) - assert cfg.model.vocab_size, "vocab size must be specified in model config" - - torch.distributed.barrier() - - pre_wrap_hook = [] - mixed_precision_wrapper = Float16Module - - use_peft = policy_cfg["megatron_cfg"].get("peft", {}).get("enabled", False) - - if policy_cfg["megatron_cfg"]["freeze_moe_router"]: - if use_peft: - raise ValueError( - "Freezing the MOE router is not currently supported when using PEFT" - ) - - def freeze_moe_router(megatron_model): - if not isinstance(megatron_model, list): - megatron_model = [megatron_model] - for model_module in megatron_model: - # Handle both wrapped (Float16Module) and unwrapped models - if isinstance(model_module, Float16Module): - model_module = model_module.module - # Handle VLM models - if hasattr(model_module, "language_model"): - model_module = model_module.language_model - for layer in model_module.decoder.layers: - if hasattr(layer, "mlp") and hasattr(layer.mlp, "router"): - layer.mlp.router.weight.requires_grad = False - - mixed_precision_wrapper = CustomFloat16Module - pre_wrap_hook.extend([freeze_moe_router]) - - if use_peft: - peft_cfg = policy_cfg["megatron_cfg"].get("peft", {}) - peft = LoRA( - target_modules=peft_cfg["target_modules"], - exclude_modules=peft_cfg["exclude_modules"], - dim=peft_cfg["dim"], - alpha=peft_cfg["alpha"], - dropout=peft_cfg["dropout"], - dropout_position=peft_cfg["dropout_position"], - lora_A_init_method=peft_cfg["lora_A_init_method"], - lora_B_init_method=peft_cfg["lora_B_init_method"], - a2a_experimental=peft_cfg["a2a_experimental"], - lora_dtype=peft_cfg["lora_dtype"], - ) - else: - peft = None - cfg.peft = peft - - if cfg.peft is not None: - 
pre_peft_hook = _create_peft_pre_wrap_hook(cfg, state) - cfg.model.register_pre_wrap_hook(pre_peft_hook) - - def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: - model = pre_peft_hook(model) - return model - - pre_wrap_hook.extend([composed_peft_hook]) - - # Model, optimizer, and learning rate. - model = get_model( - cfg.model, - cfg.ddp, - use_torch_fsdp2=cfg.dist.use_torch_fsdp2, - overlap_param_gather_with_optimizer_step=cfg.optimizer.overlap_param_gather_with_optimizer_step, - data_parallel_random_init=cfg.rng.data_parallel_random_init, - pre_wrap_hook=pre_wrap_hook, - mixed_precision_wrapper=mixed_precision_wrapper, - ) - - if load_optimizer: - optimizer, scheduler = setup_optimizer( - optimizer_config=cfg.optimizer, - scheduler_config=cfg.scheduler, - model=model, - use_gloo_process_groups=cfg.dist.use_gloo_process_groups, - ) - else: - optimizer = None - scheduler = None - - print("Model, optimizer, and learning rate scheduler built") - torch.distributed.barrier() - if cfg.peft is not None: - should_load_checkpoint = cfg.checkpoint.load is not None and checkpoint_exists( - cfg.checkpoint.load - ) - if should_load_checkpoint: - # The finetune toggle is explicitly set to True in order to avoid loading optimizer and RNG states - # This is switched off here in order to load these states from the checkpoint - cfg.checkpoint.finetune = False - else: - should_load_checkpoint = ( - cfg.checkpoint.load is not None and checkpoint_exists(cfg.checkpoint.load) - ) or ( - cfg.checkpoint.pretrained_checkpoint is not None - and checkpoint_exists(cfg.checkpoint.pretrained_checkpoint) - ) - - if should_load_checkpoint: - load_checkpoint( - state, - model, - optimizer, - scheduler, - checkpointing_context=checkpointing_context, - skip_load_to_model_and_opt=HAVE_FSDP2 and cfg.dist.use_torch_fsdp2, - ) - print("Checkpoint loaded") - torch.distributed.barrier() - - return state, model, optimizer, scheduler, checkpointing_context - - -def 
destroy_parallel_state(): - """Safely destroy parallel state and reset async call tracking. - - This function is called during initialization to clean up temporary distributed - state from model import operations. Resetting async call tracking ensures that - when the main Megatron distributed context is created, all ranks start with - consistent call_idx values for async checkpointing. - """ - if torch.distributed.is_initialized(): - try: - torch.distributed.barrier() - torch.distributed.destroy_process_group() - except: - pass # Ignore errors if already destroyed - if hasattr(parallel_state, "destroy_model_parallel"): - try: - parallel_state.destroy_model_parallel() - except: - pass # Ignore errors if already destroyed - - # Reset async calls queue to prevent call_idx mismatches after distributed context recreation - try: - import nemo.tron.utils.async_utils as nemo_async_utils - from megatron.core.dist_checkpointing.strategies.async_utils import ( - AsyncCallsQueue, - ) - - # Clean up any existing async callers first - old_call_idx = getattr(nemo_async_utils._async_calls_queue, "call_idx", None) - num_unfinalized = ( - nemo_async_utils._async_calls_queue.get_num_unfinalized_calls() - ) - if num_unfinalized > 0: - print( - f"[WARNING] Resetting async calls queue with {num_unfinalized} unfinalized calls" - ) - try: - nemo_async_utils._async_calls_queue.close() - except: - pass # Ignore errors during cleanup - # Reset the global async calls queue by creating a new instance - nemo_async_utils._async_calls_queue = AsyncCallsQueue() - print(f"[DEBUG] Reset NeMo async calls queue (old call_idx: {old_call_idx})") - except ImportError: - pass - - # Also reset the Megatron async calls queue if it exists - try: - import megatron.training.async_utils as megatron_async_utils - from megatron.core.dist_checkpointing.strategies.async_utils import ( - AsyncCallsQueue, - ) - - # Clean up any existing async callers first - old_call_idx = getattr( - 
megatron_async_utils._async_calls_queue, "call_idx", None - ) - num_unfinalized = ( - megatron_async_utils._async_calls_queue.get_num_unfinalized_calls() - ) - if num_unfinalized > 0: - print( - f"[WARNING] Resetting Megatron async calls queue with {num_unfinalized} unfinalized calls" - ) - try: - megatron_async_utils._async_calls_queue.close() - except: - pass # Ignore errors during cleanup - # Reset the Megatron global async calls queue as well - megatron_async_utils._async_calls_queue = AsyncCallsQueue() - print( - f"[DEBUG] Reset Megatron async calls queue (old call_idx: {old_call_idx})" - ) - except ImportError: - pass - - # Reset the third global async_calls instance in base strategy module - try: - import megatron.core.dist_checkpointing.strategies.base as base_strategy - from megatron.core.dist_checkpointing.strategies.async_utils import ( - AsyncCallsQueue, - ) - - # Clean up and reset the global async_calls in base strategy - old_call_idx = getattr(base_strategy.async_calls, "call_idx", None) - num_unfinalized = base_strategy.async_calls.get_num_unfinalized_calls() - if num_unfinalized > 0: - print( - f"[WARNING] Resetting base strategy async_calls with {num_unfinalized} unfinalized calls" - ) - try: - base_strategy.async_calls.close() - except: - pass - base_strategy.async_calls = AsyncCallsQueue() - print(f"[DEBUG] Reset base strategy async_calls (old call_idx: {old_call_idx})") - except ImportError: - pass - - @ray.remote( runtime_env=get_runtime_env_for_policy_worker("megatron_policy_worker") ) # pragma: no cover @@ -509,478 +183,100 @@ def __init__( init_reference_model: bool = True, *, worker_sharding_annotations: NamedSharding, - pre_init_communication_queue: Queue, **kwargs: Any, ): + """Initialize the MegatronPolicyWorker.""" + # Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files apply_transformer_engine_patch() - self.is_generation_colocated = None - if "generation" in config and config["generation"] is not None: - 
self.is_generation_colocated = config["generation"]["colocated"]["enabled"] - - # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. - # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. - if not self.is_generation_colocated: - os.environ["NCCL_CUMEM_ENABLE"] = "1" - self.cfg = config - dtype_map = { - "float32": torch.float32, - "bfloat16": torch.bfloat16, - "float16": torch.float16, - } - self.dtype = dtype_map[self.cfg["precision"]] - - self.optimizer_cpu_offload = self.cfg["megatron_cfg"]["optimizer"][ - "optimizer_cpu_offload" - ] - self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] - - # Reward models are not yet supported with Megatron. - if "reward_model_cfg" in self.cfg and self.cfg["reward_model_cfg"]["enabled"]: - raise NotImplementedError( - "Reward models are not yet supported with the Megatron backend, this issue is " - "tracked in https://github.com/NVIDIA-NeMo/RL/issues/720" - ) - - # Disable dynamo autotune_local_cache to avoid crash when there's already a cache - # with different order of node_bundles - configure_dynamo_cache() - - # cfg["model_name"] is allowed to be either an HF model name or a path to an HF checkpoint - # check if hf_model_name is a path - hf_model_name = self.cfg["model_name"] - # Check if the checkpoint already exists - hf_model_subdir = hf_model_name - if os.path.exists(hf_model_name): - hf_model_subdir = f"model_{hf_model_subdir.replace('/', '_')}" - - pretrained_path = f"{get_megatron_checkpoint_dir()}/{hf_model_subdir}" - pt_checkpoint_exists = os.path.exists(pretrained_path) and os.path.exists( - os.path.join(pretrained_path, "iter_0000000") - ) - - # Ensure clean slate before import - destroy_parallel_state() - # Set for rank for non-collocated to check which ranks to broadcast from + # Set rank for non-collocated to check which ranks to broadcast from self.rank = get_rank_safe() - # Need to initialize the process group before 
calling into Megatron-Bridge, otherwise Megatron-Bridge will try to set an incorrect device - torch.distributed.init_process_group("nccl") - if pt_checkpoint_exists: - print(f"Checkpoint already exists at {pretrained_path}. Skipping import.") - else: - hf_config_overrides = self.cfg.get("hf_config_overrides", {}) or {} - import_model_from_hf_name( - hf_model_name, - pretrained_path, - self.cfg["megatron_cfg"], - **hf_config_overrides, - ) - if parallel_state.model_parallel_is_initialized(): - print("Reinitializing model parallel after loading model state.") - parallel_state.destroy_model_parallel() + # Step 1: Setup distributed + setup_distributed() - pretrained_run_config = os.path.join( - pretrained_path, "iter_0000000/run_config.yaml" + # Step 2: Validate and setup model paths + hf_model_name, pretrained_path, pt_checkpoint_exists = validate_model_paths( + config + ) + # Handle model import if needed + handle_model_import( + config, hf_model_name, pretrained_path, pt_checkpoint_exists ) + # Store tokenizer self.tokenizer = tokenizer if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token - if not os.path.exists(pretrained_run_config): - raise FileNotFoundError( - f"Pretrained run config not found at {pretrained_run_config} on rank={get_rank_safe()}. This usually means that the one-time HF->mcore conversion on rank=0 saved to a directory not being mounted on this node. 
Please check " - ) - - try: - cfg_from_pretrained = ConfigContainer.from_yaml( - pretrained_run_config, mode=InstantiationMode.STRICT - ) - except Exception as e: - # Add helpful context as a note to the exception - e.add_note( - f"\n{'=' * 80}\n" - f"NOTE: A common cause of this error is when the HF->mcore converted checkpoint is\n" - f"created with an older version of megatron-bridge.\n" - f"If this checkpoint is old or was generated by a different code version,\n" - f"try deleting it and rerunning the code.\n" - f"The checkpoint will be automatically regenerated with the current version.\n\n" - f"Checkpoint location: {pretrained_path}\n" - f"{'=' * 80}" - ) - raise - model_cfg = cfg_from_pretrained.model - cfg_from_pretrained.logger = LoggerConfig() - - model_cfg.tensor_model_parallel_size = self.cfg["megatron_cfg"][ - "tensor_model_parallel_size" - ] - model_cfg.pipeline_model_parallel_size = self.cfg["megatron_cfg"][ - "pipeline_model_parallel_size" - ] - model_cfg.num_layers_in_first_pipeline_stage = self.cfg["megatron_cfg"][ - "num_layers_in_first_pipeline_stage" - ] - model_cfg.num_layers_in_last_pipeline_stage = self.cfg["megatron_cfg"][ - "num_layers_in_last_pipeline_stage" - ] - model_cfg.sequence_parallel = self.cfg["megatron_cfg"]["sequence_parallel"] - model_cfg.context_parallel_size = self.cfg["megatron_cfg"][ - "context_parallel_size" - ] - if model_cfg.context_parallel_size > 1: - assert self.cfg["sequence_packing"]["enabled"], ( - "Sequence Packing must be enabled to use Context Parallelism with MCore" - ) - model_cfg.expert_tensor_parallel_size = self.cfg["megatron_cfg"][ - "expert_tensor_parallel_size" - ] - model_cfg.expert_model_parallel_size = self.cfg["megatron_cfg"][ - "expert_model_parallel_size" - ] - - # Setting moe_router_dtype to higher precision (e.g. fp64) can improve numerical stability, - # especially when using many experts. 
- model_cfg.moe_router_dtype = self.cfg["megatron_cfg"]["moe_router_dtype"] - - # The below two configs (and "freeze_moe_router") are used to stabilize moe training - # by preventing updates to the moe router. We found that this is helpful in reducing - # logprob error during training. - - # Set this to "none" to disable load balancing loss. - model_cfg.moe_router_load_balancing_type = self.cfg["megatron_cfg"][ - "moe_router_load_balancing_type" - ] - # Set this to 0.0 to disable updates to the moe router expert bias - model_cfg.moe_router_bias_update_rate = self.cfg["megatron_cfg"][ - "moe_router_bias_update_rate" - ] - model_cfg.moe_enable_deepep = self.cfg["megatron_cfg"]["moe_enable_deepep"] - model_cfg.moe_token_dispatcher_type = self.cfg["megatron_cfg"][ - "moe_token_dispatcher_type" - ] - model_cfg.moe_shared_expert_overlap = self.cfg["megatron_cfg"][ - "moe_shared_expert_overlap" - ] + # Step 3: Setup model configuration + runtime_config = validate_and_set_config( + config, + self.rank, + hf_model_name, + pretrained_path, + weights_path, + tokenizer, + ) - model_cfg.moe_permute_fusion = self.cfg["megatron_cfg"]["moe_permute_fusion"] - if "layernorm_epsilon" in self.cfg["megatron_cfg"]: - model_cfg.layernorm_epsilon = self.cfg["megatron_cfg"]["layernorm_epsilon"] - - model_cfg.sequence_parallel = self.cfg["megatron_cfg"]["sequence_parallel"] - model_cfg.bf16 = self.dtype == torch.bfloat16 - model_cfg.fp16 = self.dtype == torch.float16 - if model_cfg.fp16: - assert not model_cfg.bf16, "fp16 and bf16 cannot be used together" - model_cfg.params_dtype = torch.float16 - elif model_cfg.bf16: - assert not model_cfg.fp16, "fp16 and bf16 cannot be used together" - model_cfg.params_dtype = torch.bfloat16 - else: - model_cfg.params_dtype = torch.float32 - model_cfg.pipeline_dtype = dtype_map[self.cfg["megatron_cfg"]["pipeline_dtype"]] - model_cfg.parallel_output = True - if self.cfg["megatron_cfg"]["activation_checkpointing"]: - model_cfg.recompute_granularity = "full" 
- model_cfg.recompute_method = "uniform" - model_cfg.recompute_num_layers = 1 - if not model_cfg.gated_linear_unit: - assert model_cfg.activation_func is not None, ( - "activation_func must be set if not using gated_linear_unit. This likely " - "indicates an issue in configuration conversion (e.g. activation func was " - "a lambda and couldn't be serialized). This is based on this check " - "https://github.com/NVIDIA/Megatron-LM/blob/1ab876ddc4c1893c76f26d775226a8d1dcdfb3d2/megatron/core/transformer/mlp.py#L174." - ) - model_cfg.apply_rope_fusion = self.cfg["megatron_cfg"]["apply_rope_fusion"] - model_cfg.bias_activation_fusion = self.cfg["megatron_cfg"][ - "bias_activation_fusion" - ] - fp8_cfg = self.cfg["megatron_cfg"].get("fp8_cfg", None) - self.fp8_cfg = fp8_cfg - if fp8_cfg is not None and fp8_cfg.get("enabled", False): - try: - model_cfg.fp8 = fp8_cfg["fp8"] - model_cfg.fp8_recipe = fp8_cfg["fp8_recipe"] - model_cfg.fp8_param = fp8_cfg["fp8_param"] - except KeyError as e: - raise KeyError(f"Missing key in fp8_cfg: {e}") - if model_cfg.fp8_param: - warnings.warn( - "Setting fp8_param=True sometimes causes NaN token_mult_prob_error, please use with caution. " - "Refer to https://github.com/NVIDIA-NeMo/RL/issues/1164 for latest updates with this issue." 
- ) + self.megatron_cfg = runtime_config.megatron_cfg + self.dtype = runtime_config.dtype + self.optimizer_cpu_offload = runtime_config.optimizer_cpu_offload + self.offload_optimizer_for_logprob = ( + runtime_config.offload_optimizer_for_logprob + ) + self.is_generation_colocated = runtime_config.is_generation_colocated + self.final_padded_vocab_size = runtime_config.final_padded_vocab_size - optimizer_cpu_offload = self.cfg["megatron_cfg"]["optimizer"][ - "optimizer_cpu_offload" - ] - optimizer_offload_fraction = self.cfg["megatron_cfg"]["optimizer"][ - "optimizer_offload_fraction" - ] - if optimizer_cpu_offload: - # Currently, hybrid optimizer (partly on GPU and partly on CPU) is not supported because it conflicts with the way - # Nemo-rl handles the optimizer offload/onload between generation and training. So if using CPU optimizer the offload_fraction should be 1.0. - assert optimizer_offload_fraction == 1.0, ( - "Currently for optimizer offloading, only optimizer_offload_fraction=1.0 is supported" - ) - if ( - "logprob_chunk_size" in self.cfg - and self.cfg["logprob_chunk_size"] is not None - and self.cfg["logprob_chunk_size"] > 0 - ): - assert self.cfg["megatron_cfg"]["defer_fp32_logits"], ( - "defer_fp32_logits must be True if logprob_chunk_size is set" - ) self.defer_fp32_logits = self.cfg["megatron_cfg"].get( "defer_fp32_logits", None - ) and (model_cfg.fp16 or model_cfg.bf16) - - checkpoint_config = CheckpointConfig( - save_interval=100, - save=weights_path, - load=weights_path, - pretrained_checkpoint=pretrained_path, # This is the path to the pretrained ckpt for the SFT case - async_save=False, # This doesn't work right now. 
- fully_parallel_save=True, - fully_parallel_load=True, # Enable fully parallel load - load_rng=False, - ) - ref_checkpoint_config = CheckpointConfig( - pretrained_checkpoint=pretrained_path, # This is the path to the pretrained ckpt for the SFT case - save=None, - load=None, - fully_parallel_load=True, # Enable fully parallel load - load_rng=False, - ) - - assert "train_iters" in self.cfg["megatron_cfg"], ( - "train_iters must be set in megatron_cfg. For an example, see " - "https://github.com/NVIDIA-NeMo/RL/blob/bccbc377705a81a1f4b3c31ad9767bcc15f735a8/nemo_rl/algorithms/sft.py#L175-L179." - ) + ) and (runtime_config.model_cfg.fp16 or runtime_config.model_cfg.bf16) - ## These settings are required for correct gradient computations in mcore - ## when calculate_per_token_loss is True, there is no scaling of the gradient in mcore, - ## so we handle the scaling in nemo-rl. - ## perform_initialization = True is a workaround to ensure the correct tensor parallel attributes are set - ## on the TP-sharded parameters. - model_cfg.calculate_per_token_loss = True - model_cfg.perform_initialization = True - - assert ( - "aux_loss" not in model_cfg.moe_router_load_balancing_type - or model_cfg.moe_aux_loss_coeff == 0 - ), ( - "MoE aux loss is currently not supported due to a known bug in Megatron-LM. " - "See https://github.com/NVIDIA/Megatron-LM/issues/1984 for more details." 
- ) + # Store FP8 config for later use + self.fp8_cfg = config["megatron_cfg"].get("fp8_cfg", None) - self.megatron_cfg = ConfigContainer( - model=model_cfg, - checkpoint=checkpoint_config, - logger=LoggerConfig(logging_level=0), - train=TrainingConfig( - micro_batch_size=1, # ignored - global_batch_size=self.cfg["train_global_batch_size"], # ignored - train_iters=self.cfg["megatron_cfg"][ - "train_iters" - ], # Set by algorithm setup - ), - optimizer=OptimizerConfig( - **self.cfg["megatron_cfg"]["optimizer"], - ), - ddp=DistributedDataParallelConfig( - check_for_nan_in_grad=True, - grad_reduce_in_fp32=self.cfg["megatron_cfg"][ - "distributed_data_parallel_config" - ]["grad_reduce_in_fp32"], - overlap_grad_reduce=self.cfg["megatron_cfg"][ - "distributed_data_parallel_config" - ]["overlap_grad_reduce"], - overlap_param_gather=self.cfg["megatron_cfg"][ - "distributed_data_parallel_config" - ]["overlap_param_gather"], - # we need to set average_in_collective=False with calculate_per_token_loss=True. - # otherwise, mcore throws an assertion error. - average_in_collective=False, - use_distributed_optimizer=self.cfg["megatron_cfg"]["optimizer"][ - "use_distributed_optimizer" - ], - data_parallel_sharding_strategy=self.cfg["megatron_cfg"][ - "distributed_data_parallel_config" - ]["data_parallel_sharding_strategy"], - ), - scheduler=SchedulerConfig( - **self.cfg["megatron_cfg"]["scheduler"], - ), - dataset=None, - tokenizer=TokenizerConfig( - tokenizer_type="HuggingFaceTokenizer", - tokenizer_model=hf_model_name, - ), - ) - # TODO: this validation should happen inside mbridge: https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/1665 - if self.dtype == torch.bfloat16: - assert self.megatron_cfg.model.bf16 == True, ( - "policy.megatron_cfg.model.bf16=True must be set if policy.precision=bfloat16. This is handled by nemo-rl so this indicates something is misconfigured." 
- ) - assert ( - self.megatron_cfg.optimizer.use_precision_aware_optimizer == False - or self.megatron_cfg.optimizer.bf16 == True - ), ( - "policy.megatron_cfg.optimizer.bf16=True must be set if policy.precision=bfloat16 when using use_precision_aware_optimizer=True" - ) - elif self.dtype == torch.float16: - assert self.megatron_cfg.model.fp16 == True, ( - "policy.megatron_cfg.model.fp16=True must be set if policy.precision=float16. This is handled by nemo-rl so this indicates something is misconfigured." - ) - assert ( - self.megatron_cfg.optimizer.use_precision_aware_optimizer == False - or self.megatron_cfg.optimizer.fp16 == True - ), ( - "policy.megatron_cfg.optimizer.fp16=True must be set if policy.precision=float16 when using use_precision_aware_optimizer=True" - ) - elif self.dtype == torch.float32: - assert ( - self.megatron_cfg.model.bf16 == False - and self.megatron_cfg.model.fp16 == False - ), ( - "policy.megatron_cfg.model.bf16=False and policy.megatron_cfg.model.fp16=False must be set if policy.precision=float32. This is handled by nemo-rl so this indicates something is misconfigured." 
- ) - assert ( - self.megatron_cfg.optimizer.bf16 == False - and self.megatron_cfg.optimizer.fp16 == False - ), ( - "policy.megatron_cfg.optimizer.bf16=False and policy.megatron_cfg.optimizer.fp16=False must be set if policy.precision=float32" - ) + # Validate configuration self.megatron_cfg.validate() - ( - self.mcore_state, - self.model, - self.optimizer, - self.scheduler, - self.checkpointing_context, - ) = setup_megatron_model( - policy_cfg=self.cfg, cfg=self.megatron_cfg, load_optimizer=init_optimizer + + # Step 4: Setup Megatron model and components + model_and_optimizer_state = setup_model_and_optimizer( + config, self.megatron_cfg, init_optimizer ) - # Set the param sync function for the model - if ( - self.megatron_cfg.ddp.overlap_param_gather - and self.megatron_cfg.ddp.align_param_gather - ): - self.megatron_cfg.param_sync_func = [ - model_chunk.start_param_sync for model_chunk in self.model - ] - if len(self.model) == 1: - self.megatron_cfg.param_sync_func = self.megatron_cfg.param_sync_func[0] + self.mcore_state = model_and_optimizer_state.state + self.model = model_and_optimizer_state.model + self.optimizer = model_and_optimizer_state.optimizer + self.scheduler = model_and_optimizer_state.scheduler + self.checkpointing_context = model_and_optimizer_state.checkpointing_context + param_sync_func = model_and_optimizer_state.param_sync_func - self.model = self.model[0] # Get the first model from the list + # Set the param sync function for the model if needed + if param_sync_func is not None: + self.megatron_cfg.param_sync_func = param_sync_func + # Step 5: Setup reference model if needed if init_reference_model: self.model = self.move_model(self.model, "cpu") - ref_ckpt_context = init_checkpointing_context(ref_checkpoint_config) - - # Create a separate megatron config for the reference model with the correct checkpoint config - ref_megatron_cfg = ConfigContainer( - model=self.megatron_cfg.model, - checkpoint=ref_checkpoint_config, # Use the reference 
checkpoint config - logger=self.megatron_cfg.logger, - train=self.megatron_cfg.train, - optimizer=self.megatron_cfg.optimizer, - ddp=self.megatron_cfg.ddp, - scheduler=self.megatron_cfg.scheduler, - dataset=self.megatron_cfg.dataset, - tokenizer=self.megatron_cfg.tokenizer, - ) - - # Create a separate state object for the reference model - ref_state = GlobalState() - ref_state.cfg = ref_megatron_cfg - - # Configure mixed precision wrapper for reference model - ref_mixed_precision_wrapper = Float16Module - if self.cfg["megatron_cfg"].get("freeze_moe_router", False): - ref_mixed_precision_wrapper = CustomFloat16Module - - reference_model = get_model( - self.megatron_cfg.model, - self.megatron_cfg.ddp, - use_torch_fsdp2=self.megatron_cfg.dist.use_torch_fsdp2, - overlap_param_gather_with_optimizer_step=self.megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step, - pre_wrap_hook=self.megatron_cfg.rng.data_parallel_random_init, - mixed_precision_wrapper=ref_mixed_precision_wrapper, + self.reference_state_dict = setup_reference_model_state( + config, self.megatron_cfg, pretrained_path ) - print("Loading the Reference Model") - if ( - ref_checkpoint_config.pretrained_checkpoint is not None - and checkpoint_exists(ref_checkpoint_config.pretrained_checkpoint) - ): - load_checkpoint( - ref_state, # Use the separate state object with ref checkpoint config - reference_model, - None, # no optimizer - None, # no scheduler - checkpointing_context=ref_ckpt_context, - skip_load_to_model_and_opt=HAVE_FSDP2 - and self.megatron_cfg.dist.use_torch_fsdp2, - ) - reference_model = reference_model[0] - reference_model.eval() - self.reference_state_dict = {} - for name, item in reference_model.state_dict().items(): - if isinstance(item, torch.Tensor): - cpu_item = item.detach().to( - device="cpu", non_blocking=True, copy=True - ) - del item - else: - cpu_item = item - self.reference_state_dict[name] = cpu_item - print("Reference model loaded") - else: - print("Reference model not 
loaded") - self.model = self.move_model(self.model, "cuda") - _update_model_config_funcs( - [self.model], - self.megatron_cfg.model, - self.megatron_cfg.ddp, + # Step 6: Finalize setup + ( + self.megatron_tokenizer, + self.megatron_bridge, + self.should_disable_forward_pre_hook, + self.dp_size, + ) = finalize_megatron_setup( + config, + self.megatron_cfg, + hf_model_name, + worker_sharding_annotations, + self.model, self.optimizer, - align_grad_reduce=self.megatron_cfg.dist.align_grad_reduce, - ) - - tokenizer_config = TokenizerConfig( - tokenizer_type="HuggingFaceTokenizer", - tokenizer_model=hf_model_name, - ) - - self.megatron_tokenizer = build_tokenizer( - tokenizer_config, - make_vocab_size_divisible_by=self.megatron_cfg.model.make_vocab_size_divisible_by - // self.cfg["megatron_cfg"]["tensor_model_parallel_size"], - tensor_model_parallel_size=self.cfg["megatron_cfg"][ - "tensor_model_parallel_size" - ], - trust_remote_code=True, - ) - self.final_padded_vocab_size = calculate_padded_vocab_size( - self.megatron_cfg.model.vocab_size, - self.megatron_cfg.model.make_vocab_size_divisible_by, - self.cfg["megatron_cfg"]["tensor_model_parallel_size"], - ) - self.dp_size = worker_sharding_annotations.get_axis_size("data_parallel") - self.megatron_bridge = AutoBridge.from_hf_pretrained( - hf_model_name, trust_remote_code=True - ) - - self.should_disable_forward_pre_hook = ( - self.cfg["megatron_cfg"]["optimizer"]["use_distributed_optimizer"] - and self.cfg["megatron_cfg"]["distributed_data_parallel_config"][ - "overlap_param_gather" - ] ) # vars used for refit @@ -989,9 +285,7 @@ def __init__( # [(mcore_param_name, estimated_memory), ...] # Note: here param name is local param name, with local layer number and # local expert id etc. 
- self.refit_conversion_tasks = ( - None # Meta data for conversion params from megatron bridge - ) + self.refit_conversion_tasks = None self.refit_conversion_tasks_current_index = None self.refit_param_info_mcore = None @@ -2722,40 +2016,3 @@ def _percentile(values: list[float], p: float) -> float: final_result = obj_list[0] # type: ignore return final_result - - -class CustomFloat16Module(Float16Module): - """Float 16 Module. - - Attributes: - config (TransformerConfig): Transformer config - fp16 (bool) : Specifies if the model runs in fp16 mode - bf16 (bool) : Specifies if the model runs in bf16 mode - - Args: - config (TransformerConfig): The transformer config used to initalize the model - """ - - def __init__(self, config: TransformerConfig, module: torch.nn.Module): - super(CustomFloat16Module, self).__init__(config, module) - self.re_enable_float32_expert_bias() - - def re_enable_float32_expert_bias(self) -> None: - """Ensure MoE router expert bias stays in float32 for numerical stability. - - Walks the wrapped module to find MoE routers and invokes the - `_maintain_float32_expert_bias()` helper which recreates or casts the - expert bias tensors to float32 as required by Megatron-LM. 
- """ - module = self.module - # Handle VLM models where language model is nested - if hasattr(module, "language_model"): - module = module.language_model - if hasattr(module, "decoder") and hasattr(module.decoder, "layers"): - for layer in module.decoder.layers: - mlp = getattr(layer, "mlp", None) - router = getattr(mlp, "router", None) if mlp is not None else None - if router is not None and hasattr( - router, "_maintain_float32_expert_bias" - ): - router._maintain_float32_expert_bias() diff --git a/tests/unit/models/megatron/test_megatron_setup.py b/tests/unit/models/megatron/test_megatron_setup.py new file mode 100644 index 0000000000..61c4bc7a75 --- /dev/null +++ b/tests/unit/models/megatron/test_megatron_setup.py @@ -0,0 +1,1139 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for Megatron setup utilities. 
+ +This module tests the configuration validation and setup functions in +nemo_rl.models.megatron.setup, focusing on: +- Configuration validation functions +- Parallelism configuration application +- Precision and dtype configuration +- Checkpoint configuration creation +- Model path validation +""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + + +@pytest.mark.mcore +class TestValidateModelPaths: + """Tests for validate_model_paths function.""" + + def test_model_name_is_hf_model(self, tmp_path): + """Test with a HuggingFace model name (not a local path).""" + from nemo_rl.models.megatron.setup import validate_model_paths + + config = {"model_name": "meta-llama/Llama-3.2-1B"} + + with patch( + "nemo_rl.models.megatron.setup.get_megatron_checkpoint_dir", + return_value=str(tmp_path), + ): + hf_model_name, pretrained_path, pt_checkpoint_exists = validate_model_paths( + config + ) + + assert hf_model_name == "meta-llama/Llama-3.2-1B" + assert pretrained_path == f"{tmp_path}/meta-llama/Llama-3.2-1B" + assert pt_checkpoint_exists is False + + def test_model_name_is_local_path(self, tmp_path): + """Test with a local path as model name.""" + from nemo_rl.models.megatron.setup import validate_model_paths + + local_model_path = tmp_path / "local_model" + local_model_path.mkdir() + + config = {"model_name": str(local_model_path)} + + with patch( + "nemo_rl.models.megatron.setup.get_megatron_checkpoint_dir", + return_value=str(tmp_path / "checkpoints"), + ): + hf_model_name, pretrained_path, pt_checkpoint_exists = validate_model_paths( + config + ) + + assert hf_model_name == str(local_model_path) + # Local path should be converted to model_ format + assert "model_" in pretrained_path + assert pt_checkpoint_exists is False + + def test_checkpoint_exists(self, tmp_path): + """Test when a Megatron checkpoint already exists.""" + from nemo_rl.models.megatron.setup import validate_model_paths + + # Create the checkpoint directory structure + 
checkpoint_dir = tmp_path / "checkpoints" / "test-model" + iter_dir = checkpoint_dir / "iter_0000000" + iter_dir.mkdir(parents=True) + + config = {"model_name": "test-model"} + + with patch( + "nemo_rl.models.megatron.setup.get_megatron_checkpoint_dir", + return_value=str(tmp_path / "checkpoints"), + ): + hf_model_name, pretrained_path, pt_checkpoint_exists = validate_model_paths( + config + ) + + assert hf_model_name == "test-model" + assert pt_checkpoint_exists is True + + +@pytest.mark.mcore +class TestApplyParallelismConfig: + """Tests for _apply_parallelism_config function.""" + + def test_basic_parallelism_config(self): + """Test applying basic parallelism configuration.""" + from nemo_rl.models.megatron.setup import _apply_parallelism_config + + model_cfg = MagicMock() + config = { + "megatron_cfg": { + "tensor_model_parallel_size": 4, + "pipeline_model_parallel_size": 2, + "num_layers_in_first_pipeline_stage": None, + "num_layers_in_last_pipeline_stage": None, + "sequence_parallel": True, + "context_parallel_size": 1, + }, + "sequence_packing": {"enabled": False}, + } + + _apply_parallelism_config(model_cfg, config) + + assert model_cfg.tensor_model_parallel_size == 4 + assert model_cfg.pipeline_model_parallel_size == 2 + assert model_cfg.sequence_parallel is True + assert model_cfg.context_parallel_size == 1 + + def test_context_parallel_requires_sequence_packing(self): + """Test that context parallelism > 1 requires sequence packing.""" + from nemo_rl.models.megatron.setup import _apply_parallelism_config + + model_cfg = MagicMock() + config = { + "megatron_cfg": { + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "num_layers_in_first_pipeline_stage": None, + "num_layers_in_last_pipeline_stage": None, + "sequence_parallel": False, + "context_parallel_size": 2, + }, + "sequence_packing": {"enabled": False}, + } + + with pytest.raises(AssertionError) as exc_info: + _apply_parallelism_config(model_cfg, config) + + assert "Sequence 
Packing must be enabled" in str(exc_info.value) + + def test_context_parallel_with_sequence_packing(self): + """Test context parallelism with sequence packing enabled.""" + from nemo_rl.models.megatron.setup import _apply_parallelism_config + + model_cfg = MagicMock() + config = { + "megatron_cfg": { + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "num_layers_in_first_pipeline_stage": None, + "num_layers_in_last_pipeline_stage": None, + "sequence_parallel": False, + "context_parallel_size": 4, + }, + "sequence_packing": {"enabled": True}, + } + + _apply_parallelism_config(model_cfg, config) + + assert model_cfg.context_parallel_size == 4 + + +@pytest.mark.mcore +class TestApplyMoeConfig: + """Tests for _apply_moe_config function.""" + + def test_moe_configuration(self): + """Test applying MoE configuration.""" + from nemo_rl.models.megatron.setup import _apply_moe_config + + model_cfg = MagicMock() + config = { + "megatron_cfg": { + "expert_tensor_parallel_size": 2, + "expert_model_parallel_size": 4, + "moe_router_dtype": "float32", + "moe_router_load_balancing_type": "none", + "moe_router_bias_update_rate": 0.0, + "moe_permute_fusion": True, + "moe_enable_deepep": False, + "moe_token_dispatcher_type": "allgather", + "moe_shared_expert_overlap": True, + } + } + + _apply_moe_config(model_cfg, config) + + assert model_cfg.expert_tensor_parallel_size == 2 + assert model_cfg.expert_model_parallel_size == 4 + assert model_cfg.moe_router_dtype == "float32" + assert model_cfg.moe_router_load_balancing_type == "none" + assert model_cfg.moe_router_bias_update_rate == 0.0 + assert model_cfg.moe_permute_fusion is True + assert model_cfg.moe_enable_deepep is False + assert model_cfg.moe_token_dispatcher_type == "allgather" + assert model_cfg.moe_shared_expert_overlap is True + + +@pytest.mark.mcore +class TestApplyPrecisionConfig: + """Tests for _apply_precision_config function.""" + + @pytest.mark.parametrize( + 
"dtype,expected_bf16,expected_fp16,expected_params_dtype", + [ + (torch.bfloat16, True, False, torch.bfloat16), + (torch.float16, False, True, torch.float16), + (torch.float32, False, False, torch.float32), + ], + ids=["bfloat16", "float16", "float32"], + ) + def test_precision_configurations( + self, dtype, expected_bf16, expected_fp16, expected_params_dtype + ): + """Test precision configuration for different dtypes.""" + from nemo_rl.models.megatron.setup import _apply_precision_config + + model_cfg = MagicMock() + model_cfg.bf16 = False + model_cfg.fp16 = False + config = { + "megatron_cfg": { + "pipeline_dtype": "bfloat16", + } + } + + _apply_precision_config(model_cfg, config, dtype) + + assert model_cfg.bf16 == expected_bf16 + assert model_cfg.fp16 == expected_fp16 + assert model_cfg.params_dtype == expected_params_dtype + + def test_pipeline_dtype_mapping(self): + """Test that pipeline dtype is correctly mapped.""" + from nemo_rl.models.megatron.setup import _apply_precision_config + + model_cfg = MagicMock() + model_cfg.bf16 = False + model_cfg.fp16 = False + + for dtype_str, expected_dtype in [ + ("float32", torch.float32), + ("bfloat16", torch.bfloat16), + ("float16", torch.float16), + ]: + config = { + "megatron_cfg": { + "pipeline_dtype": dtype_str, + } + } + _apply_precision_config(model_cfg, config, torch.float32) + assert model_cfg.pipeline_dtype == expected_dtype + + +@pytest.mark.mcore +class TestApplyPerformanceConfig: + """Tests for _apply_performance_config function.""" + + def test_basic_performance_config(self): + """Test applying basic performance configuration.""" + from nemo_rl.models.megatron.setup import _apply_performance_config + + model_cfg = MagicMock() + model_cfg.gated_linear_unit = True + config = { + "megatron_cfg": { + "activation_checkpointing": False, + "apply_rope_fusion": True, + "bias_activation_fusion": True, + } + } + + _apply_performance_config(model_cfg, config) + + assert model_cfg.parallel_output is True + assert 
model_cfg.apply_rope_fusion is True + assert model_cfg.bias_activation_fusion is True + + def test_activation_checkpointing_enabled(self): + """Test activation checkpointing configuration.""" + from nemo_rl.models.megatron.setup import _apply_performance_config + + model_cfg = MagicMock() + model_cfg.gated_linear_unit = True + config = { + "megatron_cfg": { + "activation_checkpointing": True, + "apply_rope_fusion": False, + "bias_activation_fusion": False, + } + } + + _apply_performance_config(model_cfg, config) + + assert model_cfg.recompute_granularity == "full" + assert model_cfg.recompute_method == "uniform" + assert model_cfg.recompute_num_layers == 1 + + def test_activation_func_required_when_not_gated(self): + """Test that activation_func is required when not using gated_linear_unit.""" + from nemo_rl.models.megatron.setup import _apply_performance_config + + model_cfg = MagicMock() + model_cfg.gated_linear_unit = False + model_cfg.activation_func = None + config = { + "megatron_cfg": { + "activation_checkpointing": False, + "apply_rope_fusion": False, + "bias_activation_fusion": False, + } + } + + with pytest.raises(AssertionError) as exc_info: + _apply_performance_config(model_cfg, config) + + assert "activation_func must be set" in str(exc_info.value) + + def test_fp8_configuration(self): + """Test FP8 configuration.""" + from nemo_rl.models.megatron.setup import _apply_performance_config + + model_cfg = MagicMock() + model_cfg.gated_linear_unit = True + config = { + "megatron_cfg": { + "activation_checkpointing": False, + "apply_rope_fusion": False, + "bias_activation_fusion": False, + "fp8_cfg": { + "enabled": True, + "fp8": "e4m3", + "fp8_recipe": "default", + "fp8_param": False, + }, + } + } + + _apply_performance_config(model_cfg, config) + + assert model_cfg.fp8 == "e4m3" + assert model_cfg.fp8_recipe == "default" + assert model_cfg.fp8_param is False + + def test_fp8_param_warning(self): + """Test that fp8_param=True generates a warning.""" + from 
nemo_rl.models.megatron.setup import _apply_performance_config + + model_cfg = MagicMock() + model_cfg.gated_linear_unit = True + config = { + "megatron_cfg": { + "activation_checkpointing": False, + "apply_rope_fusion": False, + "bias_activation_fusion": False, + "fp8_cfg": { + "enabled": True, + "fp8": "e4m3", + "fp8_recipe": "default", + "fp8_param": True, + }, + } + } + + with pytest.warns(UserWarning, match="fp8_param=True sometimes causes NaN"): + _apply_performance_config(model_cfg, config) + + +@pytest.mark.mcore +class TestValidateOptimizerConfig: + """Tests for _validate_optimizer_config function.""" + + def test_cpu_offload_requires_full_fraction(self): + """Test that CPU offload requires offload_fraction=1.0.""" + from nemo_rl.models.megatron.setup import _validate_optimizer_config + + config = { + "megatron_cfg": { + "optimizer": { + "optimizer_cpu_offload": True, + "optimizer_offload_fraction": 0.5, + } + } + } + + with pytest.raises(AssertionError) as exc_info: + _validate_optimizer_config(config) + + assert "optimizer_offload_fraction=1.0" in str(exc_info.value) + + def test_cpu_offload_with_full_fraction(self): + """Test that CPU offload works with full fraction.""" + from nemo_rl.models.megatron.setup import _validate_optimizer_config + + config = { + "megatron_cfg": { + "optimizer": { + "optimizer_cpu_offload": True, + "optimizer_offload_fraction": 1.0, + } + } + } + + # Should not raise + _validate_optimizer_config(config) + + def test_no_cpu_offload(self): + """Test configuration without CPU offload.""" + from nemo_rl.models.megatron.setup import _validate_optimizer_config + + config = { + "megatron_cfg": { + "optimizer": { + "optimizer_cpu_offload": False, + "optimizer_offload_fraction": 0.5, # Should be ignored + } + } + } + + # Should not raise + _validate_optimizer_config(config) + + +@pytest.mark.mcore +class TestValidateChunkingConfig: + """Tests for _validate_chunking_config function.""" + + def 
test_logprob_chunk_requires_defer_fp32_logits(self): + """Test that logprob chunking requires defer_fp32_logits=True.""" + from nemo_rl.models.megatron.setup import _validate_chunking_config + + config = { + "logprob_chunk_size": 1024, + "megatron_cfg": { + "defer_fp32_logits": False, + }, + } + + with pytest.raises(AssertionError) as exc_info: + _validate_chunking_config(config) + + assert "defer_fp32_logits must be True" in str(exc_info.value) + + def test_logprob_chunk_with_defer_fp32_logits(self): + """Test that logprob chunking works with defer_fp32_logits=True.""" + from nemo_rl.models.megatron.setup import _validate_chunking_config + + config = { + "logprob_chunk_size": 1024, + "megatron_cfg": { + "defer_fp32_logits": True, + }, + } + + # Should not raise + _validate_chunking_config(config) + + @pytest.mark.parametrize( + "logprob_chunk_size", + [None, 0, -1], + ids=["none", "zero", "negative"], + ) + def test_no_chunking_skips_validation(self, logprob_chunk_size): + """Test that validation is skipped when chunking is disabled.""" + from nemo_rl.models.megatron.setup import _validate_chunking_config + + config = { + "logprob_chunk_size": logprob_chunk_size, + "megatron_cfg": { + "defer_fp32_logits": False, # Doesn't matter when chunking is disabled + }, + } + + # Should not raise + _validate_chunking_config(config) + + def test_missing_logprob_chunk_size(self): + """Test that missing logprob_chunk_size is handled.""" + from nemo_rl.models.megatron.setup import _validate_chunking_config + + config = { + "megatron_cfg": { + "defer_fp32_logits": False, + }, + } + + # Should not raise + _validate_chunking_config(config) + + +@pytest.mark.mcore +class TestCreateCheckpointConfig: + """Tests for _create_checkpoint_config function.""" + + def test_basic_checkpoint_config(self, tmp_path): + """Test creating basic checkpoint configuration.""" + from nemo_rl.models.megatron.setup import _create_checkpoint_config + + pretrained_path = str(tmp_path / "pretrained") + 
weights_path = str(tmp_path / "weights") + + checkpoint_config = _create_checkpoint_config(pretrained_path, weights_path) + + assert checkpoint_config.save == weights_path + assert checkpoint_config.load == weights_path + assert checkpoint_config.pretrained_checkpoint == pretrained_path + assert checkpoint_config.async_save is False + assert checkpoint_config.fully_parallel_save is True + assert checkpoint_config.fully_parallel_load is True + assert checkpoint_config.load_rng is False + + +@pytest.mark.mcore +class TestValidateTrainingConfig: + """Tests for _validate_training_config function.""" + + def test_train_iters_required(self): + """Test that train_iters must be set.""" + from nemo_rl.models.megatron.setup import _validate_training_config + + model_cfg = MagicMock() + model_cfg.moe_router_load_balancing_type = "none" + model_cfg.moe_aux_loss_coeff = 0 + config = { + "megatron_cfg": {}, + } + + with pytest.raises(AssertionError) as exc_info: + _validate_training_config(config, model_cfg) + + assert "train_iters must be set" in str(exc_info.value) + + def test_training_config_sets_required_flags(self): + """Test that training config sets required model flags.""" + from nemo_rl.models.megatron.setup import _validate_training_config + + model_cfg = MagicMock() + model_cfg.moe_router_load_balancing_type = "none" + model_cfg.moe_aux_loss_coeff = 0 + config = { + "megatron_cfg": { + "train_iters": 1000, + }, + } + + _validate_training_config(config, model_cfg) + + assert model_cfg.calculate_per_token_loss is True + assert model_cfg.perform_initialization is True + + def test_moe_aux_loss_not_supported(self): + """Test that MoE aux loss is not supported.""" + from nemo_rl.models.megatron.setup import _validate_training_config + + model_cfg = MagicMock() + model_cfg.moe_router_load_balancing_type = "aux_loss" + model_cfg.moe_aux_loss_coeff = 0.1 # Non-zero + config = { + "megatron_cfg": { + "train_iters": 1000, + }, + } + + with pytest.raises(AssertionError) as 
exc_info: + _validate_training_config(config, model_cfg) + + assert "MoE aux loss is currently not supported" in str(exc_info.value) + + def test_moe_aux_loss_with_zero_coeff_is_ok(self): + """Test that MoE aux loss with zero coefficient is allowed.""" + from nemo_rl.models.megatron.setup import _validate_training_config + + model_cfg = MagicMock() + model_cfg.moe_router_load_balancing_type = "aux_loss" + model_cfg.moe_aux_loss_coeff = 0 # Zero is OK + config = { + "megatron_cfg": { + "train_iters": 1000, + }, + } + + # Should not raise + _validate_training_config(config, model_cfg) + + +@pytest.mark.mcore +class TestValidateDtypeConfig: + """Tests for _validate_dtype_config function.""" + + def test_bfloat16_validation(self): + """Test bfloat16 dtype validation.""" + from nemo_rl.models.megatron.setup import _validate_dtype_config + + model_cfg = MagicMock() + model_cfg.bf16 = True + model_cfg.fp16 = False + + optimizer_cfg = MagicMock() + optimizer_cfg.use_precision_aware_optimizer = False + optimizer_cfg.bf16 = False + optimizer_cfg.fp16 = False + + # Should not raise + _validate_dtype_config(torch.bfloat16, model_cfg, optimizer_cfg) + + def test_bfloat16_model_flag_mismatch(self): + """Test bfloat16 validation fails when model.bf16=False.""" + from nemo_rl.models.megatron.setup import _validate_dtype_config + + model_cfg = MagicMock() + model_cfg.bf16 = False # Mismatch! 
+ model_cfg.fp16 = False + + optimizer_cfg = MagicMock() + optimizer_cfg.use_precision_aware_optimizer = False + + with pytest.raises(AssertionError) as exc_info: + _validate_dtype_config(torch.bfloat16, model_cfg, optimizer_cfg) + + assert "bf16=True must be set" in str(exc_info.value) + + def test_bfloat16_with_precision_aware_optimizer(self): + """Test bfloat16 with precision aware optimizer requires optimizer.bf16=True.""" + from nemo_rl.models.megatron.setup import _validate_dtype_config + + model_cfg = MagicMock() + model_cfg.bf16 = True + model_cfg.fp16 = False + + optimizer_cfg = MagicMock() + optimizer_cfg.use_precision_aware_optimizer = True + optimizer_cfg.bf16 = False # Mismatch! + + with pytest.raises(AssertionError) as exc_info: + _validate_dtype_config(torch.bfloat16, model_cfg, optimizer_cfg) + + assert "optimizer.bf16=True must be set" in str(exc_info.value) + + def test_float16_validation(self): + """Test float16 dtype validation.""" + from nemo_rl.models.megatron.setup import _validate_dtype_config + + model_cfg = MagicMock() + model_cfg.bf16 = False + model_cfg.fp16 = True + + optimizer_cfg = MagicMock() + optimizer_cfg.use_precision_aware_optimizer = False + + # Should not raise + _validate_dtype_config(torch.float16, model_cfg, optimizer_cfg) + + def test_float16_model_flag_mismatch(self): + """Test float16 validation fails when model.fp16=False.""" + from nemo_rl.models.megatron.setup import _validate_dtype_config + + model_cfg = MagicMock() + model_cfg.bf16 = False + model_cfg.fp16 = False # Mismatch! 
+ + optimizer_cfg = MagicMock() + optimizer_cfg.use_precision_aware_optimizer = False + + with pytest.raises(AssertionError) as exc_info: + _validate_dtype_config(torch.float16, model_cfg, optimizer_cfg) + + assert "fp16=True must be set" in str(exc_info.value) + + def test_float32_validation(self): + """Test float32 dtype validation.""" + from nemo_rl.models.megatron.setup import _validate_dtype_config + + model_cfg = MagicMock() + model_cfg.bf16 = False + model_cfg.fp16 = False + + optimizer_cfg = MagicMock() + optimizer_cfg.bf16 = False + optimizer_cfg.fp16 = False + + # Should not raise + _validate_dtype_config(torch.float32, model_cfg, optimizer_cfg) + + def test_float32_with_bf16_model_flag(self): + """Test float32 validation fails when model has bf16=True.""" + from nemo_rl.models.megatron.setup import _validate_dtype_config + + model_cfg = MagicMock() + model_cfg.bf16 = True # Mismatch! + model_cfg.fp16 = False + + optimizer_cfg = MagicMock() + optimizer_cfg.bf16 = False + optimizer_cfg.fp16 = False + + with pytest.raises(AssertionError) as exc_info: + _validate_dtype_config(torch.float32, model_cfg, optimizer_cfg) + + assert "bf16=False" in str(exc_info.value) + + def test_float32_with_fp16_optimizer_flag(self): + """Test float32 validation fails when optimizer has fp16=True.""" + from nemo_rl.models.megatron.setup import _validate_dtype_config + + model_cfg = MagicMock() + model_cfg.bf16 = False + model_cfg.fp16 = False + + optimizer_cfg = MagicMock() + optimizer_cfg.bf16 = False + optimizer_cfg.fp16 = True # Mismatch! 
+ + with pytest.raises(AssertionError) as exc_info: + _validate_dtype_config(torch.float32, model_cfg, optimizer_cfg) + + assert "optimizer" in str(exc_info.value).lower() + + +@pytest.mark.mcore +class TestValidateAndSetConfig: + """Tests for validate_and_set_config function.""" + + def test_reward_model_not_supported(self): + """Test that reward models are not supported.""" + from nemo_rl.models.megatron.setup import validate_and_set_config + + config = { + "reward_model_cfg": {"enabled": True}, + "precision": "bfloat16", + "megatron_cfg": { + "optimizer": { + "optimizer_cpu_offload": False, + }, + }, + "offload_optimizer_for_logprob": False, + } + + with pytest.raises(NotImplementedError) as exc_info: + validate_and_set_config( + config=config, + rank=0, + hf_model_name="test-model", + pretrained_path="/path/to/model", + weights_path=None, + tokenizer=MagicMock(), + ) + + assert "Reward models are not yet supported" in str(exc_info.value) + + def test_generation_colocation_detection(self): + """Test that generation colocation is properly detected.""" + # This test would require more mocking to fully test + # For now, we just verify the config parsing works + from nemo_rl.models.megatron.setup import validate_and_set_config + + config = { + "generation": { + "colocated": {"enabled": True}, + }, + "precision": "bfloat16", + "megatron_cfg": { + "optimizer": { + "optimizer_cpu_offload": False, + }, + "tensor_model_parallel_size": 2, + }, + "offload_optimizer_for_logprob": False, + } + + # The function would fail on setup_model_config, but we test the initial parsing + with patch( + "nemo_rl.models.megatron.setup.setup_model_config" + ) as mock_setup_model_config: + mock_megatron_cfg = MagicMock() + mock_megatron_cfg.model.vocab_size = 32000 + mock_setup_model_config.return_value = (mock_megatron_cfg, MagicMock()) + + with patch( + "nemo_rl.models.megatron.setup.calculate_padded_vocab_size", + return_value=32000, + ): + runtime_config = validate_and_set_config( + 
config=config, + rank=0, + hf_model_name="test-model", + pretrained_path="/path/to/model", + weights_path=None, + tokenizer=MagicMock(), + ) + + assert runtime_config.is_generation_colocated is True + + +@pytest.mark.mcore +class TestRuntimeConfigNamedTuple: + """Tests for RuntimeConfig named tuple.""" + + def test_runtime_config_fields(self): + """Test that RuntimeConfig has all expected fields.""" + from nemo_rl.models.megatron.config import RuntimeConfig + + runtime_config = RuntimeConfig( + megatron_cfg=MagicMock(), + model_cfg=MagicMock(), + dtype=torch.bfloat16, + optimizer_cpu_offload=False, + offload_optimizer_for_logprob=True, + is_generation_colocated=True, + final_padded_vocab_size=32000, + ) + + assert runtime_config.dtype == torch.bfloat16 + assert runtime_config.optimizer_cpu_offload is False + assert runtime_config.offload_optimizer_for_logprob is True + assert runtime_config.is_generation_colocated is True + assert runtime_config.final_padded_vocab_size == 32000 + + +@pytest.mark.mcore +class TestModelAndOptimizerStateNamedTuple: + """Tests for ModelAndOptimizerState named tuple.""" + + def test_model_and_optimizer_state_fields(self): + """Test that ModelAndOptimizerState has all expected fields.""" + from nemo_rl.models.megatron.config import ModelAndOptimizerState + + state = ModelAndOptimizerState( + state=MagicMock(), + model=MagicMock(), + optimizer=MagicMock(), + scheduler=MagicMock(), + checkpointing_context={"test": "context"}, + param_sync_func=lambda: None, + ) + + assert state.checkpointing_context == {"test": "context"} + assert callable(state.param_sync_func) + + +@pytest.mark.mcore +class TestHandleModelImport: + """Tests for handle_model_import function.""" + + def test_skip_import_when_checkpoint_exists(self, tmp_path, capsys): + """Test that import is skipped when checkpoint exists.""" + from nemo_rl.models.megatron.setup import handle_model_import + + pretrained_path = str(tmp_path / "model") + config = {"model_name": "test-model", 
"megatron_cfg": {}} + + handle_model_import( + config, "test-model", pretrained_path, pt_checkpoint_exists=True + ) + + captured = capsys.readouterr() + assert "Checkpoint already exists" in captured.out + + @patch("nemo_rl.models.megatron.setup.import_model_from_hf_name") + @patch("nemo_rl.models.megatron.setup.parallel_state") + def test_import_when_checkpoint_missing(self, mock_ps, mock_import, tmp_path): + """Test that model is imported when checkpoint doesn't exist.""" + from nemo_rl.models.megatron.setup import handle_model_import + + mock_ps.model_parallel_is_initialized.return_value = False + + pretrained_path = str(tmp_path / "model") + config = { + "model_name": "test-model", + "megatron_cfg": {"some_config": "value"}, + "hf_config_overrides": None, + } + + handle_model_import( + config, "test-model", pretrained_path, pt_checkpoint_exists=False + ) + + mock_import.assert_called_once_with( + "test-model", + pretrained_path, + {"some_config": "value"}, + ) + + @patch("nemo_rl.models.megatron.setup.import_model_from_hf_name") + @patch("nemo_rl.models.megatron.setup.parallel_state") + def test_reinitialize_parallel_state_after_import( + self, mock_ps, mock_import, tmp_path, capsys + ): + """Test that parallel state is destroyed after model import.""" + from nemo_rl.models.megatron.setup import handle_model_import + + mock_ps.model_parallel_is_initialized.return_value = True + + pretrained_path = str(tmp_path / "model") + config = { + "model_name": "test-model", + "megatron_cfg": {}, + "hf_config_overrides": {}, + } + + handle_model_import( + config, "test-model", pretrained_path, pt_checkpoint_exists=False + ) + + mock_ps.destroy_model_parallel.assert_called_once() + + captured = capsys.readouterr() + assert "Reinitializing model parallel" in captured.out + + +@pytest.mark.mcore +class TestSetupModelAndOptimizer: + """Tests for setup_model_and_optimizer function.""" + + @patch("nemo_rl.models.megatron.setup.GlobalState") + 
@patch("nemo_rl.models.megatron.setup.initialize_megatron") + @patch("nemo_rl.models.megatron.setup.set_jit_fusion_options") + @patch("nemo_rl.models.megatron.setup.init_checkpointing_context") + @patch("nemo_rl.models.megatron.setup.build_tokenizer") + @patch("nemo_rl.models.megatron.setup.get_model") + @patch("nemo_rl.models.megatron.setup.setup_optimizer") + @patch("nemo_rl.models.megatron.setup.checkpoint_exists") + @patch("nemo_rl.models.megatron.setup.MoEFloat16Module") + @patch("torch.distributed.all_reduce") + @patch("torch.distributed.barrier") + @patch("torch.tensor") + def test_setup_with_param_sync_and_frozen_moe_router( + self, + mock_tensor, + mock_barrier, + mock_all_reduce, + mock_custom_float16, + mock_checkpoint_exists, + mock_setup_optimizer, + mock_get_model, + mock_build_tokenizer, + mock_init_ckpt_context, + mock_set_jit, + mock_init_megatron, + mock_global_state, + ): + """Test setup_model_and_optimizer with MoE router freezing.""" + from nemo_rl.models.megatron.setup import setup_model_and_optimizer + + # Setup mocks + mock_state = MagicMock() + mock_state.start_time = 0.0 + mock_global_state.return_value = mock_state + + mock_megatron_cfg = MagicMock() + mock_megatron_cfg.ft = None + mock_megatron_cfg.model.vocab_size = 32000 + mock_megatron_cfg.model.make_vocab_size_divisible_by = 128 + mock_megatron_cfg.model.tensor_model_parallel_size = 1 + # Enable param gather overlap + mock_megatron_cfg.ddp.overlap_param_gather = True + mock_megatron_cfg.ddp.align_param_gather = True + mock_megatron_cfg.checkpoint.load = None + mock_megatron_cfg.checkpoint.pretrained_checkpoint = None + + mock_model_chunk = MagicMock() + mock_model_chunk.start_param_sync = MagicMock() + mock_model = [mock_model_chunk] + mock_get_model.return_value = mock_model + + mock_optimizer = MagicMock() + mock_scheduler = MagicMock() + mock_setup_optimizer.return_value = (mock_optimizer, mock_scheduler) + + mock_tensor_instance = MagicMock() + 
mock_tensor_instance.item.return_value = 0.0 + mock_tensor.return_value = mock_tensor_instance + + mock_checkpoint_exists.return_value = False + + policy_cfg = { + "megatron_cfg": { + "freeze_moe_router": True, # Enable MoE router freezing + } + } + + result = setup_model_and_optimizer( + policy_cfg=policy_cfg, + megatron_cfg=mock_megatron_cfg, + load_optimizer=True, + ) + + # Verify get_model was called (the mixed_precision_wrapper should be MoEFloat16Module) + mock_get_model.assert_called_once() + call_kwargs = mock_get_model.call_args[1] + # Check that pre_wrap_hook is not empty when freeze_moe_router is True + assert len(call_kwargs.get("pre_wrap_hook", [])) > 0 + + assert result.param_sync_func == mock_model_chunk.start_param_sync + + +@pytest.mark.mcore +class TestSetupReferenceModelState: + """Tests for setup_reference_model_state function.""" + + @patch("nemo_rl.models.megatron.setup.init_checkpointing_context") + @patch("nemo_rl.models.megatron.setup.GlobalState") + @patch("nemo_rl.models.megatron.setup.get_model") + @patch("nemo_rl.models.megatron.setup.checkpoint_exists") + @patch("nemo_rl.models.megatron.setup.load_checkpoint") + @patch("nemo_rl.models.megatron.setup.HAVE_FSDP2", False) + def test_setup_reference_model( + self, + mock_load_checkpoint, + mock_checkpoint_exists, + mock_get_model, + mock_global_state, + mock_init_ckpt_context, + capsys, + ): + """Test setup_reference_model_state when checkpoint exists.""" + from nemo_rl.models.megatron.setup import setup_reference_model_state + + # Setup mocks + mock_state = MagicMock() + mock_global_state.return_value = mock_state + + mock_megatron_cfg = MagicMock() + mock_megatron_cfg.dist.use_torch_fsdp2 = False + + # Create mock model with state dict + mock_model = MagicMock() + mock_model.state_dict.return_value = { + "layer1.weight": torch.tensor([1.0, 2.0]), + "layer1.bias": torch.tensor([0.1]), + } + mock_get_model.return_value = [mock_model] + + mock_checkpoint_exists.return_value = True + + 
config = { + "megatron_cfg": { + "freeze_moe_router": False, + } + } + + result = setup_reference_model_state( + config=config, + megatron_cfg=mock_megatron_cfg, + pretrained_path="/path/to/pretrained", + ) + + # Verify checkpoint was loaded + mock_load_checkpoint.assert_called_once() + + # Verify model was set to eval mode + mock_model.eval.assert_called_once() + + # Verify state dict is returned + assert isinstance(result, dict) + assert "layer1.weight" in result + assert "layer1.bias" in result + + # Verify tensors are on CPU + assert result["layer1.weight"].device.type == "cpu" + + captured = capsys.readouterr() + assert "Reference model loaded" in captured.out + + +@pytest.mark.mcore +class TestFinalizeMegatronSetup: + """Tests for finalize_megatron_setup function.""" + + @patch("nemo_rl.models.megatron.setup._update_model_config_funcs") + @patch("nemo_rl.models.megatron.setup.build_tokenizer") + @patch("nemo_rl.models.megatron.setup.AutoBridge") + def test_basic_finalize_setup( + self, + mock_auto_bridge, + mock_build_tokenizer, + mock_update_model_config, + ): + """Test basic finalize_megatron_setup.""" + from nemo_rl.models.megatron.setup import finalize_megatron_setup + + # Setup mocks + mock_megatron_cfg = MagicMock() + mock_megatron_cfg.model.make_vocab_size_divisible_by = 128 + + mock_model = MagicMock() + mock_optimizer = MagicMock() + + mock_worker_sharding = MagicMock() + mock_worker_sharding.get_axis_size.return_value = 4 # dp_size = 4 + + mock_tokenizer = MagicMock() + mock_build_tokenizer.return_value = mock_tokenizer + + mock_bridge = MagicMock() + mock_auto_bridge.from_hf_pretrained.return_value = mock_bridge + + config = { + "megatron_cfg": { + "tensor_model_parallel_size": 2, + "optimizer": { + "use_distributed_optimizer": False, + }, + "distributed_data_parallel_config": { + "overlap_param_gather": False, + }, + } + } + + result = finalize_megatron_setup( + config=config, + megatron_cfg=mock_megatron_cfg, + hf_model_name="test-model", + 
worker_sharding_annotations=mock_worker_sharding, + model=mock_model, + optimizer=mock_optimizer, + ) + + # Verify return values + megatron_tokenizer, megatron_bridge, should_disable_hook, dp_size = result + assert megatron_tokenizer == mock_tokenizer + assert megatron_bridge == mock_bridge + assert should_disable_hook is False + assert dp_size == 4 + + # Verify function calls + mock_update_model_config.assert_called_once() + mock_build_tokenizer.assert_called_once() + mock_auto_bridge.from_hf_pretrained.assert_called_once_with( + "test-model", trust_remote_code=True + )