diff --git a/docs/guides/eagle3-speculative-decoding.md b/docs/guides/eagle3-speculative-decoding.md new file mode 100644 index 0000000000..4194736982 --- /dev/null +++ b/docs/guides/eagle3-speculative-decoding.md @@ -0,0 +1,171 @@ +# Train with Eagle3 Speculative Decoding + +Eagle3 speculative decoding speeds up rollout generation by running a smaller draft model in vLLM and having the policy model verify its proposals. In NeMo RL, you can either use a fixed Eagle3 draft model only for generation, or train that draft model online during RL so it stays aligned with the policy. + +This guide covers the NeMo RL-specific runtime and training path. For a high-level overview of speculative decoding, see [An Introduction to Speculative Decoding for Reducing Latency in AI Inference](https://developer.nvidia.com/blog/an-introduction-to-speculative-decoding-for-reducing-latency-in-ai-inference/). For GRPO fundamentals, see the [GRPO guide](grpo.md). For asynchronous rollout collection, see the [Async GRPO guide](async-grpo.md). + +## Offline vs Online + +- **Offline draft model**: vLLM uses a fixed Eagle3 checkpoint for speculative decoding, but the RL training loop does not update that draft model. +- **Online draft training**: NeMo RL attaches an Eagle3 draft model to the Megatron policy worker, trains it alongside the policy, and refits both policy and draft weights into vLLM. + +Use the offline path when you already have a good drafter and only want faster rollouts. Use the online path when the policy is changing during RL and you want the drafter to track those updates. + +## Draft Checkpoint + +For the best results, start from an Eagle checkpoint that was already pretrained as a draft model, then use NeMo RL's online draft loss to keep it aligned with the policy during RL. For training or adapting an Eagle checkpoint, see the [Model Optimizer speculative decoding example](https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/speculative_decoding/README.md). 
+ +If you are using a separately trained Eagle checkpoint, make sure its `eagle_config.json` contains: + +```json +{ + "has_lm_head": true +} +``` + +NeMo RL now keeps a trainer-owned draft LM head. If the draft checkpoint contains +`lm_head.weight`, NeMo RL loads it into the draft model. If that weight is absent, +NeMo RL initializes the draft LM head from the current policy output layer instead. + +## Enablement + +### Generation Only + +```yaml +policy: + generation: + backend: "vllm" + vllm_kwargs: + speculative_config: + method: "eagle3" + model: /path/to/eagle3-draft + num_speculative_tokens: 3 +``` + +This enables Eagle3 in vLLM, but the trainer does not own or update the draft model. + +### Online Draft Training + +```yaml +policy: + megatron_cfg: + enabled: true + + draft: + enabled: true + model_name: ${policy.generation.vllm_kwargs.speculative_config.model} + loss_weight: 1.0 + + generation: + backend: "vllm" + vllm_kwargs: + speculative_config: + method: "eagle3" + model: /path/to/eagle3-draft + num_speculative_tokens: 3 + draft_tensor_parallel_size: 1 +``` + +> [!NOTE] +> Online draft training currently requires the Megatron backend. NeMo RL rejects `policy.draft.enabled=true` on the DTensor path. + +## How It Works + +### Rollout Path + +During generation, vLLM runs the Eagle3 drafter from `policy.generation.vllm_kwargs.speculative_config`. When `policy.draft.enabled=true`, NeMo RL refits both: + +- the policy weights into the main vLLM model +- the `draft.*` weights into the vLLM drafter + +That keeps the rollout drafter aligned with the latest RL-updated policy instead of a stale checkpoint. + +### Training Path + +During the policy forward pass, NeMo RL captures: + +- token input embeddings +- a small set of intermediate hidden states from auxiliary policy layers + +Those captured activations are the Eagle inputs. If `policy.draft.aux_layer_indices` is not set, NeMo RL chooses a default early/middle/late set of policy layers. 
The draft model then predicts logits in the same vocabulary space by reusing the policy LM head.
+
+### Draft Loss and Time-Step Alignment
+
+The draft loss compares draft logits against detached policy logits, but only after aligning both sides to the same next-token event.
+
+Suppose the policy input sequence is:
+
+```text
+[BOS, The, cat, sat]
+```
+
+The policy forward pass produces hidden states and logits at those positions:
+
+```text
+position:      0     1     2     3
+input token:   [BOS] [The] [cat] [sat]
+hidden state:  h0    h1    h2    h3
+policy logits: p0    p1    p2    p3
+predicts:      The   cat   sat   EOS
+```
+
+For Eagle training, NeMo RL does not compare raw `p0, p1, p2, p3` directly to the raw draft output. Instead it shifts the draft inputs and teacher outputs so draft position `t` predicts the teacher distribution for position `t+1`.
+
+First, it rolls the captured input embeddings left by one token:
+
+```text
+original embeddings: e(BOS) e(The) e(cat) e(sat)
+shifted embeddings:  e(The) e(cat) e(sat) -
+```
+
+Then it rolls the detached teacher logits left by one position:
+
+```text
+original teacher logits: p0  p1  p2  p3
+rolled teacher logits:   p1  p2  p3  -
+teacher meaning:         cat sat EOS -
+```
+
+So the aligned draft-training pairs become:
+
+```text
+(h0, e(The)) -> p1
+(h1, e(cat)) -> p2
+(h2, e(sat)) -> p3
+```
+
+In words:
+
+- use the hidden state at position `t`
+- use the embedding of the token at position `t+1`
+- predict the teacher distribution for position `t+1`
+
+After this alignment, the draft loss is:
+
+$$
+L_{draft} = \mathbb{E}_t \left[ -\sum_v \mathrm{softmax}(z_{policy,t})_v \log \mathrm{softmax}(z_{draft,t})_v \right]
+$$
+
+Here `z_{policy,t}` and `z_{draft,t}` refer to the aligned tensors after rolling, truncation, and masking, not the raw unshifted outputs of the forward pass.
+
+This has the same student gradient as forward KL from the policy distribution to the draft distribution, up to the teacher entropy constant. 
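
The shift-and-mask bookkeeping above can be sketched in a few lines of plain PyTorch on toy tensors. This is an illustration, not the NeMo RL implementation: `torch.roll` stands in for Megatron's `roll_tensor`, and all tensor names are made up for the example.

```python
import torch

T, V, H = 4, 8, 16                      # toy sequence length, vocab size, hidden size
torch.manual_seed(0)

policy_logits = torch.randn(T, V)       # p0..p3, detached teacher outputs
input_embeds = torch.randn(T, H)        # e(BOS), e(The), e(cat), e(sat)
draft_logits = torch.randn(T, V, requires_grad=True)  # stand-in draft outputs

# Roll embeddings and teacher logits left by one so position t targets t+1.
# shifted_embeds would feed the draft model together with captured hidden states.
shifted_embeds = torch.roll(input_embeds, shifts=-1, dims=0)
teacher_logits = torch.roll(policy_logits, shifts=-1, dims=0)

# torch.roll wraps the first element around to the end; that last position has
# no t+1 target, so it is masked out of the loss.
mask = torch.ones(T)
mask[-1] = 0.0

# Soft cross-entropy H(p_teacher, q_draft) per position, then a masked mean.
teacher_probs = torch.softmax(teacher_logits, dim=-1)
draft_log_probs = torch.log_softmax(draft_logits, dim=-1)
per_token = -(teacher_probs * draft_log_probs).sum(dim=-1)
draft_loss = (per_token * mask).sum() / mask.sum()

draft_loss.backward()
# The masked final position contributes no gradient to the draft logits.
assert torch.allclose(draft_logits.grad[-1], torch.zeros(V))
```

Note that the gradient with respect to each unmasked draft position is `softmax(draft) - softmax(teacher)`, which is exactly the forward-KL student gradient mentioned above.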
The total training objective is: + +$$ +L_{total} = L_{policy} + \lambda \cdot L_{draft} +$$ + +where `lambda` is `policy.draft.loss_weight`. + +## Important Knobs + +- `policy.draft.enabled`: attach and train the Eagle draft model +- `policy.draft.model_name`: checkpoint used to initialize the draft model +- `policy.draft.loss_weight`: weight on the auxiliary draft loss +- `policy.draft.aux_layer_indices`: policy layers whose hidden states feed Eagle +- `policy.generation.vllm_kwargs.speculative_config.num_speculative_tokens`: number of speculative tokens proposed by vLLM + +## Notes + +- When online draft training is enabled, NeMo RL logs `draft_loss`. +- Resume checkpoints include the nested draft model state when `policy.draft.enabled=true`. +- If speculative decoding is enabled without trainer-owned draft weights, vLLM must load real draft weights at startup. When the trainer owns the draft model, the first refit pushes both policy and draft parameters. diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 9f377ae9cf..ee44cf0982 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -212,6 +212,13 @@ policy: env_vars: null + draft: + enabled: false + model_name: null + loss_weight: 0.1 + num_layers: null + aux_layer_indices: null + # See docs/design-docs/sequence-packing-and-dynamic-batching.md # for more details on dynamic batching and sequence packing. 
dynamic_batching: diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3.yaml new file mode 100644 index 0000000000..7aae458b33 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3.yaml @@ -0,0 +1,35 @@ +defaults: ./grpo-llama3.2-1b-instruct-1n4g-megatron.yaml + +checkpointing: + checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3 + +grpo: + num_prompts_per_step: 4 + num_generations_per_prompt: 4 + +policy: + train_global_batch_size: 16 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 256 + sequence_packing: + enabled: false + draft: + enabled: true + model_name: "__set_nrl_eagle3_draft_model__" + loss_weight: 1.0 + generation: + max_new_tokens: 256 + vllm_cfg: + max_model_len: 256 + vllm_kwargs: + speculative_config: + method: "eagle3" + model: ${policy.draft.model_name} + num_speculative_tokens: 3 + draft_tensor_parallel_size: 1 + +logger: + log_dir: logs/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3 + wandb: + name: grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3 diff --git a/examples/run_grpo.py b/examples/run_grpo.py index 6130b99018..b8f6025067 100644 --- a/examples/run_grpo.py +++ b/examples/run_grpo.py @@ -84,8 +84,11 @@ def main() -> None: assert config["policy"]["generation"] is not None, ( "A generation config is required for GRPO" ) + has_refit_draft_weights = bool(config["policy"]["draft"]["enabled"]) config["policy"]["generation"] = configure_generation_config( - config["policy"]["generation"], tokenizer + config["policy"]["generation"], + tokenizer, + has_refit_draft_weights=has_refit_draft_weights, ) # setup data diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index e550429ce2..02e43ae659 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -2118,6 +2118,8 @@ def grpo_train( print("\nšŸ“Š Training 
Results:") print(f" • Loss: {metrics['loss']:.4f}") + if "draft_loss" in metrics: + print(f" • Draft Loss: {metrics['draft_loss']:.4f}") print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") if master_config["grpo"]["use_dynamic_sampling"]: print(f" • Avg Filtered Reward: {np.mean(rewards.numpy()):.4f}") @@ -3130,6 +3132,8 @@ def async_grpo_train( print("\nšŸ“Š Training Results:") print(f" • Loss: {metrics['loss']:.4f}") + if "draft_loss" in metrics: + print(f" • Draft Loss: {metrics['draft_loss']:.4f}") print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") print(f" • Buffer Size: {buffer_size_current}") diff --git a/nemo_rl/algorithms/loss/__init__.py b/nemo_rl/algorithms/loss/__init__.py index a2d404fdaa..ede13323b1 100644 --- a/nemo_rl/algorithms/loss/__init__.py +++ b/nemo_rl/algorithms/loss/__init__.py @@ -22,6 +22,7 @@ DPOLossConfig, DPOLossDataDict, DPOLossFn, + DraftCrossEntropyLossFn, NLLLossFn, PreferenceLossDataDict, PreferenceLossFn, @@ -31,6 +32,7 @@ prepare_packed_loss_input, ) from nemo_rl.algorithms.loss.wrapper import ( + DraftLossWrapper, SequencePackingFusionLossWrapper, SequencePackingLossWrapper, wrap_loss_fn_with_input_preparation, @@ -53,5 +55,6 @@ "prepare_packed_loss_input", "SequencePackingFusionLossWrapper", "SequencePackingLossWrapper", + "DraftLossWrapper", "wrap_loss_fn_with_input_preparation", ] diff --git a/nemo_rl/algorithms/loss/interfaces.py b/nemo_rl/algorithms/loss/interfaces.py index f1c0db3e35..abf0b89095 100644 --- a/nemo_rl/algorithms/loss/interfaces.py +++ b/nemo_rl/algorithms/loss/interfaces.py @@ -29,6 +29,7 @@ class LossInputType(enum.Enum): LOGIT = "logit" LOGPROB = "logprob" DISTILLATION = "distillation" + DRAFT = "draft" class LossFunction(Protocol): @@ -66,6 +67,7 @@ def __call__( - For LossInputType.LOGPROB: next_token_logprobs (torch.Tensor) - For LossInputType.LOGIT: logits (torch.Tensor) - For LossInputType.DISTILLATION: 
student_topk_logprobs, teacher_topk_logprobs, H_all (torch.Tensor) + - For LossInputType.DRAFT: teacher_logits, student_logits, mask (torch.Tensor) Returns: tuple: (loss, metrics) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index ab05586d7c..15c58c0353 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -12,17 +12,75 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, NotRequired, TypedDict, TypeVar +from typing import Any, NotRequired, Optional, TypedDict, TypeVar import torch from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType, LossType from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import DistributedCrossEntropy Tensor = TypeVar("Tensor", bound=torch.Tensor) +class DraftCrossEntropyLossConfig(TypedDict): + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] + + +class DraftCrossEntropyLossDataDict(TypedDict): + teacher_logits: Tensor + student_logits: Tensor + token_mask: Tensor + sample_mask: Tensor + student_vocab_indices: NotRequired[Tensor] + + +class DraftCrossEntropyLossFn(LossFunction): + """Compute the auxiliary soft-target cross-entropy used for draft-model training.""" + + loss_type = LossType.TOKEN_LEVEL + input_type = LossInputType.DRAFT + + def __init__( + self, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + ): + self.vocab_parallel_rank = vocab_parallel_rank + self.vocab_parallel_group = vocab_parallel_group + + def __call__( + self, + teacher_logits: Tensor, + student_logits: Tensor, + token_mask: Tensor, + data: BatchedDataDict[DraftCrossEntropyLossDataDict], + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + ) -> 
torch.Tensor: + """Reduce the masked per-token draft loss to a scalar.""" + if self.vocab_parallel_group is not None: + # Soft cross entropy matches the forward-KL student gradient. + per_token_loss = DistributedCrossEntropy.apply( + student_logits, + teacher_logits, + self.vocab_parallel_group, + False, + ) + else: + teacher_probs = torch.nn.functional.softmax(teacher_logits, dim=-1) + student_log_probs = torch.nn.functional.log_softmax(student_logits, dim=-1) + per_token_loss = -(teacher_probs * student_log_probs).sum(dim=-1) + + mask = token_mask * data["sample_mask"].unsqueeze(-1) + return masked_mean( + per_token_loss, + mask, + global_normalization_factor=global_valid_toks, + ) + + class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float reference_policy_kl_type: str diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 15e19037df..6263c299c4 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -123,6 +123,35 @@ def prepare_loss_input( "teacher_topk_logprobs": teacher_topk_logprobs, "H_all": H_all, } + elif loss_fn.input_type == LossInputType.DRAFT: + from megatron.core.transformer.multi_token_prediction import roll_tensor + + teacher_logits = roll_tensor( + logits.detach(), + shifts=-1, + dims=1, + cp_group=context_parallel_group, + )[0] + if "t2d" in data: + t2d = data["t2d"].to(teacher_logits.device) + if vocab_parallel_group is not None: + from megatron.core.transformer.utils import ( + gather_from_tensor_model_parallel_region, + ) + + teacher_logits = gather_from_tensor_model_parallel_region( + teacher_logits, vocab_parallel_group + ) + teacher_logits = teacher_logits[:, :, t2d] + + token_mask = roll_tensor( + data["token_mask"], shifts=-1, dims=1, cp_group=context_parallel_group + )[0] + loss_input = { + "teacher_logits": teacher_logits, + "student_logits": data["student_logits"], + "token_mask": token_mask, + } else: raise ValueError(f"Unknown loss function input type: 
{loss_fn.input_type}")
diff --git a/nemo_rl/algorithms/loss/wrapper.py b/nemo_rl/algorithms/loss/wrapper.py
index e27095379a..318de43fd4 100644
--- a/nemo_rl/algorithms/loss/wrapper.py
+++ b/nemo_rl/algorithms/loss/wrapper.py
@@ -13,14 +13,15 @@
 #  limitations under the License.
 
 import math
 from typing import Any, Callable, Optional, TypeVar
 
 import torch
 import torch.distributed
 
 from nemo_rl.algorithms.loss.interfaces import LossFunction
+from nemo_rl.algorithms.loss.loss_functions import DraftCrossEntropyLossFn
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 
 Tensor = TypeVar("Tensor", bound=torch.Tensor)
 
 
@@ -227,6 +231,67 @@ def __call__(
         )
 
 
+class DraftLossWrapper:
+    """Combine policy loss with draft soft cross-entropy loss."""
+
+    def __init__(
+        self,
+        loss_fn: Callable[..., tuple[torch.Tensor, dict[str, Any]]],
+        prepare_fn: Callable[..., Any],
+        data_dict: BatchedDataDict[Any],
+        loss_weight: float = 1.0,
+        vocab_parallel_rank: Optional[int] = None,
+        vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+        context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+    ):
+        self.loss_fn = loss_fn
+        self.prepare_fn = prepare_fn
+        self.data_dict = data_dict
+        self.loss_weight = loss_weight
+        self.vocab_parallel_rank = vocab_parallel_rank
+        self.vocab_parallel_group = vocab_parallel_group
+        self.context_parallel_group = context_parallel_group
+        self.draft_loss_fn = DraftCrossEntropyLossFn(
+            vocab_parallel_group=vocab_parallel_group
+        )
+
+    def __call__(
+        self,
+        next_token_logits: torch.Tensor,
+        data: BatchedDataDict[Any],
+        global_valid_seqs: torch.Tensor | None,
+        global_valid_toks: torch.Tensor | None,
+        **kwargs: Any,
+    ) -> tuple[torch.Tensor, dict[str, Any]]:
+        if global_valid_toks is None:
+            raise ValueError("global_valid_toks is required for DraftLossWrapper.")
+        policy_loss, metrics = self.loss_fn(
+            next_token_logits,
+            data,
+            global_valid_seqs,
+            global_valid_toks,
+            **kwargs,
+        )
+
+        loss_input, data = self.prepare_fn(
+            next_token_logits,
+            data,
+            self.draft_loss_fn,
+            self.vocab_parallel_rank,
+            self.vocab_parallel_group,
+            self.context_parallel_group,
+        )
+        draft_loss = self.draft_loss_fn(
+            data=data,
+            global_valid_seqs=global_valid_seqs,
+            global_valid_toks=global_valid_toks,
+            **loss_input,
+        )
+        combined_loss = policy_loss + self.loss_weight * draft_loss
+        metrics["draft_loss"] = float(draft_loss.detach().item())
+        return combined_loss, metrics
+
+
 def wrap_loss_fn_with_input_preparation(
     next_token_logits: Tensor,
     data: BatchedDataDict[Any],
diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py
index 3c98cdbca1..dd46ff7a27 100644
--- a/nemo_rl/distributed/model_utils.py
+++ b/nemo_rl/distributed/model_utils.py
@@ -68,6 +68,44 @@ def _compute_distributed_log_softmax(
     return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype)
 
 
+@torch.no_grad()
+def _compute_distributed_softmax(
+    vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup
+) -> torch.Tensor:
+    """Compute a stable distributed softmax across tensor parallel workers.
+
+    Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L239
+
+    Args:
+        vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_length, vocab_size//TP]
+            where TP is the tensor parallel size.
+        group (torch.distributed.ProcessGroup): Process group for the all-reduce operations.
+
+    Returns:
+        torch.Tensor: Softmax output with the same shape as input, normalized across the full vocabulary. 
+ """ + logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True) + torch.distributed.all_reduce( + logits_max, + op=torch.distributed.ReduceOp.MAX, + group=group, + ) + + vocab_parallel_logits = vocab_parallel_logits - logits_max + + exp_logits = vocab_parallel_logits.exp_() + + sum_exp_logits = exp_logits.sum(-1, keepdim=True) + torch.distributed.all_reduce( + sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=group, + ) + exp_logits.div_(sum_exp_logits) + + return exp_logits + + class DistributedLogprob(torch.autograd.Function): """Custom autograd function for computing log probabilities in a distributed setting. @@ -152,6 +190,63 @@ def backward( return grad_input, None, None, None, None, None, None +class DistributedCrossEntropy(torch.autograd.Function): + """Compute soft-target cross entropy across TP-sharded vocab. + + This returns H(p_target, q_student), which matches forward KL up to the + target entropy constant. Backward propagates only through student logits. + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] + ctx: Any, + student_logits: torch.Tensor, + target_logits: torch.Tensor, + group: torch.distributed.ProcessGroup, + inference_only: bool = False, + ) -> torch.Tensor: + if student_logits.shape != target_logits.shape: + raise ValueError( + "student_logits and target_logits must have the same shape, " + f"got {student_logits.shape} and {target_logits.shape}." + ) + + target_probs = _compute_distributed_softmax( + target_logits.to(dtype=torch.float32), + group=group, + ) + student_log_probs = _compute_distributed_log_softmax( + student_logits.to(dtype=torch.float32), group=group + ) + # Reuse the log-softmax buffers to avoid extra full-vocab allocations. 
+ local_cross_entropy = torch.einsum( + "...v,...v->...", target_probs, student_log_probs + ).neg_() + torch.distributed.all_reduce( + local_cross_entropy, + op=torch.distributed.ReduceOp.SUM, + group=group, + ) + + if not inference_only: + student_probs = student_log_probs.exp_() + ctx.save_for_backward(target_probs, student_probs) + + return local_cross_entropy.contiguous() + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None]: + grad_output = grad_outputs[0] + target_probs, student_probs = ctx.saved_tensors + + # d(H(p, q))/d(z_v) = q_v - p_v + grad_student = (student_probs - target_probs) * grad_output.unsqueeze(-1) + return grad_student, None, None, None + + class ChunkedDistributedLogprob(torch.autograd.Function): """Custom autograd function for computing log probabilities in a distributed setting. diff --git a/nemo_rl/models/generation/__init__.py b/nemo_rl/models/generation/__init__.py index 90575e77b1..8465112c6c 100644 --- a/nemo_rl/models/generation/__init__.py +++ b/nemo_rl/models/generation/__init__.py @@ -23,7 +23,10 @@ def configure_generation_config( - config: GenerationConfig, tokenizer: TokenizerType, is_eval=False + config: GenerationConfig, + tokenizer: TokenizerType, + is_eval: bool = False, + has_refit_draft_weights: bool = False, ) -> GenerationConfig: """Apply specific configurations to generation config.""" # tokenizer setting @@ -42,17 +45,17 @@ def configure_generation_config( config = cast(VllmConfig, config) # set load_format config["vllm_cfg"]["load_format"] = "auto" if is_eval else "dummy" - is_spec = "speculative_config" in config.get("vllm_kwargs", {}) - if is_spec: - # When speculative decoding is enabled but the draft model is not co-trained - # with the policy (i.e., no weight sync for the draft model), we must use - # load_format='auto' to load actual weights. Using 'dummy' would leave the - # draft model with random weights that never get updated. 
- warnings.warn( - "Speculative decoding is enabled. Setting vllm_cfg['load_format'] to 'auto'. " - "This may result in slower startup times as full model weights are loaded." - ) - config["vllm_cfg"]["load_format"] = "auto" + speculative_config = config.get("vllm_kwargs", {}).get("speculative_config") + if speculative_config: + # Speculative decoding needs real startup weights unless the draft + # weights will be pushed into vLLM during the initial refit. + if not is_eval and not has_refit_draft_weights: + warnings.warn( + "Speculative decoding is enabled without draft refit sync. " + "Setting vllm_cfg['load_format'] to 'auto' so the drafter does " + "not start from dummy weights." + ) + config["vllm_cfg"]["load_format"] = "auto" # Respect the skip_tokenizer_init setting from the config. VLMs for example, require this to be False. if "skip_tokenizer_init" not in config["vllm_cfg"]: diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 5d239fd902..b189f990df 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -122,6 +122,35 @@ def _maybe_process_fp8_kv_cache(self) -> None: target_device, ) + @staticmethod + def _split_policy_and_draft_weights( + weights: list[tuple[str, torch.Tensor]], + ) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]: + policy_weights = [] + draft_weights = [] + for key, tensor in weights: + if key.startswith("draft."): + draft_weights.append((key.removeprefix("draft."), tensor)) + else: + policy_weights.append((key, tensor)) + return policy_weights, draft_weights + + def _load_draft_weights( + self, draft_weights: list[tuple[str, torch.Tensor]] + ) -> None: + if not draft_weights: + return + + draft_owner = getattr(self.model_runner, "drafter", None) + draft_model = getattr(draft_owner, "model", None) if draft_owner else None + if draft_model is None: + print( + "[draft] Received draft weights but 
vLLM drafter is unavailable; skipping draft update." + ) + return + + draft_model.load_weights(weights=draft_weights) + @wrap_with_nvtx_name("vllm_internal_worker_extension/update_weights_via_ipc_zmq") def update_weights_via_ipc_zmq(self) -> bool: """Receive and update model weights via ZMQ IPC socket. @@ -131,6 +160,8 @@ def update_weights_via_ipc_zmq(self) -> bool: """ buffer = None weights = None + policy_weights = None + draft_weights = None try: self.maybe_init_zmq() @@ -176,11 +207,16 @@ def update_weights_via_ipc_zmq(self) -> bool: # Load weights into the model from nemo_rl.models.generation.vllm.quantization import fp8 + policy_weights, draft_weights = self._split_policy_and_draft_weights( + weights + ) if fp8.is_fp8_model(self.model_runner.vllm_config): # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights(weights, self.model_runner) + fp8.load_weights(policy_weights, self.model_runner) else: - self.model_runner.model.load_weights(weights=weights) + self.model_runner.model.load_weights(weights=policy_weights) + + self._load_draft_weights(draft_weights) torch.cuda.current_stream().synchronize() @@ -189,8 +225,10 @@ def update_weights_via_ipc_zmq(self) -> bool: # copied the data, Python may not garbage collect these view objects immediately. # If sender reuses the buffer before GC runs, old views would read corrupted data. # Explicit del ensures immediate cleanup before sending ACK. 
- del weights, buffer + del weights, policy_weights, draft_weights, buffer weights = None + policy_weights = None + draft_weights = None buffer = None self.zmq_socket.send(IPCProtocol.ACK.value.encode()) @@ -229,11 +267,17 @@ def _load_model_weights(weights, model_runner): """ from nemo_rl.models.generation.vllm.quantization import fp8 + policy_weights, draft_weights = self._split_policy_and_draft_weights( + weights + ) + if fp8.is_fp8_model(model_runner.vllm_config): # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights(weights, model_runner) + fp8.load_weights(policy_weights, model_runner) else: - model_runner.model.load_weights(weights=weights) + model_runner.model.load_weights(weights=policy_weights) + + self._load_draft_weights(draft_weights) load_model_weight_func = lambda x: _load_model_weights(x, self.model_runner) diff --git a/nemo_rl/models/megatron/config.py b/nemo_rl/models/megatron/config.py index 7a65ed1924..d55101ba2b 100644 --- a/nemo_rl/models/megatron/config.py +++ b/nemo_rl/models/megatron/config.py @@ -75,3 +75,4 @@ class ModelAndOptimizerState(NamedTuple): scheduler: OptimizerParamScheduler checkpointing_context: dict[str, Any] param_sync_func: Optional[Callable] + draft_model: Optional[MegatronModule] = None diff --git a/nemo_rl/models/megatron/draft/__init__.py b/nemo_rl/models/megatron/draft/__init__.py new file mode 100644 index 0000000000..85180148c6 --- /dev/null +++ b/nemo_rl/models/megatron/draft/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_rl.models.megatron.draft.eagle import EagleModel +from nemo_rl.models.megatron.draft.hidden_capture import ( + CapturedStates, + HiddenStateCapture, + get_capture_context, + get_eagle3_aux_hidden_state_layers, +) +from nemo_rl.models.megatron.draft.utils import ( + export_eagle_weights_to_hf, + load_hf_weights_to_eagle, +) + +__all__ = [ + "CapturedStates", + "HiddenStateCapture", + "get_capture_context", + "EagleModel", + "load_hf_weights_to_eagle", + "export_eagle_weights_to_hf", + "get_eagle3_aux_hidden_state_layers", +] diff --git a/nemo_rl/models/megatron/draft/eagle.py b/nemo_rl/models/megatron/draft/eagle.py new file mode 100644 index 0000000000..4dc94e84b0 --- /dev/null +++ b/nemo_rl/models/megatron/draft/eagle.py @@ -0,0 +1,124 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import torch
+from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+from megatron.core.models.common.embeddings import RotaryEmbedding
+from megatron.core.transformer import MegatronModule, TransformerConfig
+from megatron.core.transformer.utils import (
+    ensure_metadata_has_dp_cp_group,
+    sharded_state_dict_default,
+)
+from modelopt.torch.speculative.plugins.megatron_eagle import EagleModule
+from torch import Tensor
+
+
+class EagleModel(MegatronModule):
+    def __init__(self, config: TransformerConfig):
+        super().__init__(config=config)
+        self.config = config
+
+        rotary_pos_emb = RotaryEmbedding(
+            kv_channels=config.kv_channels,
+            rotary_percent=1.0,
+            rotary_interleaved=False,
+            seq_len_interpolation_factor=None,
+            rotary_base=getattr(config, "rotary_base", 10000),
+            rope_scaling=getattr(config, "rope_scaling", False),
+            rope_scaling_factor=getattr(config, "rope_scaling_factor", 8.0),
+            use_cpu_initialization=getattr(
+                config,
+                "use_cpu_initialization",
+                not torch.cuda.is_available(),
+            ),
+        )
+        self.eagle_module = EagleModule(
+            config=config, rotary_pos_emb=rotary_pos_emb, bias=False
+        )
+
+    def sharded_state_dict(
+        self,
+        prefix: str = "",
+        sharded_offsets: Tuple[Tuple[int, int, int], ...] = (),
+        metadata: Optional[dict] = None,
+    ) -> ShardedStateDict:
+        """Override to fix a bug in modelopt < 0.42.0.
+
+        In modelopt < 0.42.0, EagleTransformerBlock.sharded_state_dict omits
+        tp_group when calling sharded_state_dict_default for non-layer children
+        (e.g. final_layernorm). This causes make_sharded_tensors_for_checkpoint
+        to receive tp_group=None while dp_cp_group is set, so the
+        ``tp_group is None and dp_cp_group is None`` guard never fires, and
+        get_pg_rank(None)=0 is used for all TP ranks. With TP>1 and DP>1, two
+        ranks end up with replica_id=(0,0,0), triggering CheckpointingException.
+        """
+        sd = super().sharded_state_dict(
+            prefix=prefix, sharded_offsets=sharded_offsets, metadata=metadata
+        )
+
+        decoder = self.eagle_module.decoder
+        if not hasattr(decoder, "layers"):
+            return sd
+
+        metadata = ensure_metadata_has_dp_cp_group(metadata)
+
+        # Regenerate all non-layer children of the decoder with the correct
+        # tp_group. EagleTransformerBlock asserts sharded_offsets=() so we
+        # always use () here too.
+        for name, module in decoder.named_children():
+            if module is decoder.layers:
+                continue
+            child_prefix = f"{prefix}eagle_module.decoder.{name}."
+            for k in list(sd):
+                if k.startswith(child_prefix):
+                    del sd[k]
+            sd.update(
+                sharded_state_dict_default(
+                    module,
+                    child_prefix,
+                    (),
+                    metadata,
+                    tp_group=decoder.tp_group,
+                )
+            )
+
+        return sd
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        input_embeds: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        bootstrap_hidden_states: bool = True,
+    ) -> Tensor:
+        if bootstrap_hidden_states:
+            hidden_states = self.eagle_module.fc(hidden_states)[0]
+        elif hidden_states.shape[-1] != self.config.hidden_size:
+            raise ValueError(
+                f"Expected hidden states with size {self.config.hidden_size} when "
+                f"`bootstrap_hidden_states=False`, got {hidden_states.shape[-1]}."
+            )
+
+        hidden_states, _ = self.eagle_module(
+            embeddings=input_embeds,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+        )
+        logits, _ = self.eagle_module.eagle_output_layer(hidden_states)
+        logits = logits.transpose(0, 1).contiguous()
+        return logits
diff --git a/nemo_rl/models/megatron/draft/hidden_capture.py b/nemo_rl/models/megatron/draft/hidden_capture.py
new file mode 100644
index 0000000000..f0a6b0c2a8
--- /dev/null
+++ b/nemo_rl/models/megatron/draft/hidden_capture.py
@@ -0,0 +1,323 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
+from typing import ContextManager, Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+from megatron.core import parallel_state
+from megatron.training.utils import unwrap_model
+from torch import Tensor, nn
+
+
+def get_eagle3_aux_hidden_state_layers(num_layers: int) -> tuple[int, ...]:
+    """Pick the default auxiliary policy layers whose activations feed Eagle training."""
+    candidate_indices = (
+        1,
+        max(0, num_layers // 2 - 1),
+        max(1, num_layers - 4),
+    )
+    valid_indices = sorted(set(candidate_indices))
+    return tuple(valid_indices)
+
+
+_DTYPE_TO_CODE = {
+    torch.float16: 0,
+    torch.bfloat16: 1,
+    torch.float32: 2,
+}
+
+_CODE_TO_DTYPE = {code: dtype for dtype, code in _DTYPE_TO_CODE.items()}
+
+
+@dataclass
+class CapturedStates:
+    """Container for hidden states captured from the policy model."""
+
+    hidden_states: Optional[Tensor] = None
+    inputs_embeds: Optional[Tensor] = None
+
+
+class HiddenStateCapture:
+    """Capture policy embeddings and auxiliary hidden states for Eagle training."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        aux_layer_indices: Optional[Tuple[int, ...]] = None,
+    ):
+        self.model = unwrap_model(model)
+        self.num_layers = self.model.config.num_layers
+
+        self.aux_layer_indices = (
+            aux_layer_indices
+            if aux_layer_indices is not None
+            else get_eagle3_aux_hidden_state_layers(self.num_layers)
+        )
+
+        self.pp_size = parallel_state.get_pipeline_model_parallel_world_size()
+        self.pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+        self.is_first_stage = parallel_state.is_pipeline_first_stage()
+        self.is_last_stage = parallel_state.is_pipeline_last_stage()
+
+        self._global_to_local: Dict[int, int] = {}
+        self._local_aux_indices: List[int] = []
+        self._compute_local_layer_mapping()
+        self._layer_owner_by_global_idx = self._compute_layer_owner_map()
+
+        self._captured: Dict[str, Tensor] = {}
+        self._hooks: List[torch.utils.hooks.RemovableHandle] = []
+
+    def _compute_local_layer_mapping(self) -> None:
+        for local_idx, layer in enumerate(self.model.decoder.layers):
+            global_idx = int(layer.layer_number) - 1
+            if global_idx in self.aux_layer_indices:
+                self._global_to_local[global_idx] = local_idx
+                self._local_aux_indices.append(local_idx)
+
+    def _compute_layer_owner_map(self) -> Dict[int, int]:
+        if self.pp_size == 1 or not dist.is_initialized():
+            return {layer_idx: 0 for layer_idx in range(self.num_layers)}
+
+        pp_group = parallel_state.get_pipeline_model_parallel_group()
+        local_owner_mask = torch.zeros(
+            self.num_layers,
+            dtype=torch.int64,
+            device=torch.cuda.current_device(),
+        )
+        for layer in self.model.decoder.layers:
+            global_idx = int(layer.layer_number) - 1
+            if 0 <= global_idx < self.num_layers:
+                local_owner_mask[global_idx] = 1
+
+        gathered_owner_masks = [
+            torch.zeros_like(local_owner_mask) for _ in range(self.pp_size)
+        ]
+        dist.all_gather(gathered_owner_masks, local_owner_mask, group=pp_group)
+
+        owner_map: Dict[int, int] = {}
+        for global_idx in range(self.num_layers):
+            for rank_idx, owner_mask in enumerate(gathered_owner_masks):
+                if int(owner_mask[global_idx].item()) == 1:
+                    owner_map[global_idx] = rank_idx
+                    break
+        return owner_map
+
+    def _make_layer_output_hook(self, global_idx: int):
+        def hook(_module, _args, output):
+            hidden_states = output[0] if isinstance(output, tuple) else output
+            if hidden_states is None:
+                return
+            self._captured[f"layer_{global_idx}"] = hidden_states.detach().clone()
+
+        return hook
+
+    def _make_embedding_hook(self):
+        def hook(_module, _args, output):
+            self._captured["embeds"] = output.detach().clone()
+
+        return hook
+
+    def register_hooks(self) -> None:
+        self.clear_hooks()
+        self._captured.clear()
+
+        if self.is_first_stage and hasattr(self.model, "embedding"):
+            self._hooks.append(
+                self.model.embedding.register_forward_hook(self._make_embedding_hook())
+            )
+
+        for local_idx in self._local_aux_indices:
+            layer = self.model.decoder.layers[local_idx]
+            global_idx = int(layer.layer_number) - 1
+            self._hooks.append(
+                layer.register_forward_hook(self._make_layer_output_hook(global_idx))
+            )
+
+    def clear_hooks(self) -> None:
+        for handle in self._hooks:
+            handle.remove()
+        self._hooks.clear()
+
+    @contextmanager
+    def capture_context(self):
+        try:
+            self.register_hooks()
+            yield self
+        finally:
+            self.clear_hooks()
+
+    def _assemble_local_states(self) -> CapturedStates:
+        embeds = self._captured.get("embeds")
+
+        hidden_chunks = []
+        for global_idx in sorted(self.aux_layer_indices):
+            tensor = self._captured.get(f"layer_{global_idx}")
+            if tensor is not None:
+                hidden_chunks.append(tensor)
+
+        if not hidden_chunks:
+            return CapturedStates(hidden_states=None, inputs_embeds=embeds)
+
+        return CapturedStates(
+            hidden_states=torch.cat(hidden_chunks, dim=-1),
+            inputs_embeds=embeds,
+        )
+
+    def _owner_rank_for_global_layer(self, global_layer_idx: int) -> int:
+        if self.pp_size == 1:
+            return 0
+        if global_layer_idx in self._layer_owner_by_global_idx:
+            return self._layer_owner_by_global_idx[global_layer_idx]
+        layers_per_rank = max(1, self.num_layers // self.pp_size)
+        return min(global_layer_idx // layers_per_rank, self.pp_size - 1)
+
+    @staticmethod
+    def _send_tensor(
+        tensor: Tensor,
+        dst_rank: int,
+        group: dist.ProcessGroup,
+    ) -> None:
+        dtype_code = _DTYPE_TO_CODE.get(tensor.dtype)
+        if dtype_code is None:
+            raise ValueError(f"Unsupported tensor dtype for send/recv: {tensor.dtype}")
+
+        metadata = torch.tensor(
+            [tensor.shape[0], tensor.shape[1], tensor.shape[2], dtype_code],
+            dtype=torch.int64,
+            device=tensor.device,
+        )
+        dist.send(metadata, dst=dst_rank, group=group)
+        dist.send(tensor.contiguous(), dst=dst_rank, group=group)
+
+    @staticmethod
+    def _recv_tensor(
+        src_rank: int,
+        group: dist.ProcessGroup,
+        device: torch.device,
+    ) -> Tensor:
+        metadata = torch.empty(4, dtype=torch.int64, device=device)
+        dist.recv(metadata, src=src_rank, group=group)
+        seq_len, batch_size, hidden_size, dtype_code = [
+            int(x) for x in metadata.tolist()
+        ]
+        dtype = _CODE_TO_DTYPE.get(dtype_code)
+        if dtype is None:
+            raise ValueError(
+                f"Unsupported tensor dtype code in send/recv: {dtype_code}"
+            )
+
+        received = torch.empty(
+            seq_len,
+            batch_size,
+            hidden_size,
+            dtype=dtype,
+            device=device,
+        )
+        dist.recv(received, src=src_rank, group=group)
+        return received
+
+    def _gather_distributed(self) -> CapturedStates:
+        pp_group = parallel_state.get_pipeline_model_parallel_group()
+        last_rank = self.pp_size - 1
+        recv_device = torch.device("cuda", torch.cuda.current_device())
+
+        sample_tensor = None
+        for tensor in self._captured.values():
+            if tensor is not None:
+                sample_tensor = tensor
+                break
+
+        if sample_tensor is None and not self.is_last_stage:
+            return CapturedStates()
+
+        gathered_hidden_by_layer: Dict[int, Tensor] = {}
+
+        for global_idx in self.aux_layer_indices:
+            owner_rank = self._owner_rank_for_global_layer(global_idx)
+            key = f"layer_{global_idx}"
+
+            if self.pp_rank == owner_rank:
+                layer_tensor = self._captured.get(key)
+                if layer_tensor is None:
+                    continue
+                if self.is_last_stage:
+                    gathered_hidden_by_layer[global_idx] = layer_tensor
+                else:
+                    self._send_tensor(layer_tensor, dst_rank=last_rank, group=pp_group)
+            elif self.is_last_stage:
+                received = self._recv_tensor(
+                    src_rank=owner_rank,
+                    group=pp_group,
+                    device=recv_device,
+                )
+                gathered_hidden_by_layer[global_idx] = received
+
+        gathered_embeds = None
+        if self.is_first_stage:
+            embeds = self._captured.get("embeds")
+            if embeds is not None:
+                if self.is_last_stage:
+                    gathered_embeds = embeds
+                else:
+                    self._send_tensor(embeds, dst_rank=last_rank, group=pp_group)
+        elif self.is_last_stage:
+            gathered_embeds = self._recv_tensor(
+                src_rank=0,
+                group=pp_group,
+                device=recv_device,
+            )
+
+        if not self.is_last_stage:
+            return CapturedStates()
+
+        if gathered_hidden_by_layer:
+            hidden_states = torch.cat(
+                [
+                    gathered_hidden_by_layer[layer]
+                    for layer in sorted(gathered_hidden_by_layer.keys())
+                ],
+                dim=-1,
+            )
+        else:
+            hidden_states = None
+
+        return CapturedStates(
+            hidden_states=hidden_states,
+            inputs_embeds=gathered_embeds,
+        )
+
+    def get_captured_states(self) -> CapturedStates:
+        if self.pp_size == 1:
+            return self._assemble_local_states()
+        return self._gather_distributed()
+
+
+def get_capture_context(
+    model: nn.Module,
+    enabled: bool = False,
+    aux_layer_indices: Optional[Tuple[int, ...]] = None,
+) -> Tuple[ContextManager, Optional[HiddenStateCapture]]:
+    """Return a no-op context unless draft training needs hidden-state capture for this step."""
+    if not enabled:
+        return nullcontext(), None
+    capture = HiddenStateCapture(
+        model=model,
+        aux_layer_indices=aux_layer_indices,
+    )
+    return capture.capture_context(), capture
diff --git a/nemo_rl/models/megatron/draft/utils.py b/nemo_rl/models/megatron/draft/utils.py
new file mode 100644
index 0000000000..3d4e129173
--- /dev/null
+++ b/nemo_rl/models/megatron/draft/utils.py
@@ -0,0 +1,1310 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Callable, Mapping
+
+import torch
+import torch.distributed as dist
+from megatron.core import parallel_state
+from megatron.core.process_groups_config import ProcessGroupCollection
+from megatron.core.transformer import MegatronModule, TransformerConfig
+from megatron.training.utils import unwrap_model
+from torch import Tensor
+
+StateDict = dict[str, Tensor]
+CheckpointLoader = Callable[[Path], StateDict]
+
+_CHECKPOINT_CANDIDATE_NAMES = (
+    "model.safetensors",
+    "model.safetensors.index.json",
+    "pytorch_model.bin",
+    "pytorch_model.bin.index.json",
+)
+_HF_SNAPSHOT_ALLOW_PATTERNS = [
+    "model.safetensors",
+    "model-*.safetensors",
+    "model.safetensors.index.json",
+    "pytorch_model.bin",
+    "pytorch_model-*.bin",
+    "pytorch_model.bin.index.json",
+]
+_HF_SNAPSHOT_IGNORE_PATTERNS = ["*.pt", "*.pth", "*.ckpt"]
+_MODEL_LAYER_QKV_KEY_PATTERN = re.compile(
+    r"^eagle_module\.decoder\.layers\.(\d+)\.self_attention\.linear_qkv\.weight$"
+)
+_CHECKPOINT_LAYER_KEY_PATTERN = re.compile(r"^layers\.(\d+)\.(.+)$")
+
+
+@dataclass(frozen=True)
+class _EagleLayerLayout:
+    layer_index: int
+    model_prefix: str
+    checkpoint_prefix: str
+    hidden_norm_key: str | None
+    input_layernorm_key: str | None
+    post_attention_layernorm_key: str | None
+
+    @property
+    def qkv_weight_key(self) -> str:
+        return f"{self.model_prefix}.self_attention.linear_qkv.weight"
+
+    @property
+    def proj_weight_key(self) -> str:
+        return f"{self.model_prefix}.self_attention.linear_proj.weight"
+
+    @property
+    def fc1_weight_key(self) -> str:
+        return f"{self.model_prefix}.mlp.linear_fc1.weight"
+
+    @property
+    def fc2_weight_key(self) -> str:
+        return f"{self.model_prefix}.mlp.linear_fc2.weight"
+
+
+def _resolve_optional_key(
+    model_keys: set[str],
+    *candidates: str | None,
+) -> str | None:
+    for candidate in candidates:
+        if candidate is not None and candidate in model_keys:
+            return candidate
+    return None
+
+
+@dataclass(frozen=True)
+class _EagleModelLayout:
+    layers: tuple[_EagleLayerLayout, ...]
+    final_norm_key: str | None
+    lm_head_key: str | None
+
+    @classmethod
+    def detect(cls, model_state: Mapping[str, Tensor]) -> _EagleModelLayout:
+        model_keys = set(model_state)
+        layer_indices = sorted(
+            int(match.group(1))
+            for key in model_keys
+            if (match := _MODEL_LAYER_QKV_KEY_PATTERN.match(key)) is not None
+        )
+
+        if layer_indices:
+            layer_prefixes = {
+                layer_index: f"eagle_module.decoder.layers.{layer_index}"
+                for layer_index in layer_indices
+            }
+        elif "eagle_module.layer.self_attention.linear_qkv.weight" in model_keys:
+            layer_prefixes = {0: "eagle_module.layer"}
+        else:
+            raise RuntimeError(
+                "Unable to detect Eagle layer prefix from model state dict."
+            )
+
+        final_norm_key = _resolve_optional_key(
+            model_keys,
+            "eagle_module.decoder.final_layernorm.weight",
+            "eagle_module.norm.weight",
+        )
+        lm_head_key = _resolve_optional_key(
+            model_keys,
+            "eagle_module.eagle_output_layer.weight",
+            "eagle_module.lm_head.weight",
+        )
+        global_hidden_norm_key = _resolve_optional_key(
+            model_keys,
+            "eagle_module.hidden_norm.weight",
+            "eagle_module.hnorm.weight",
+            "eagle_module.pre_fc_norm_hidden.weight",
+            "eagle_module.enorm.weight",
+        )
+
+        use_midlayer_alias = len(layer_prefixes) == 1 and 0 in layer_prefixes
+        layers = tuple(
+            _EagleLayerLayout(
+                layer_index=layer_index,
+                model_prefix=layer_prefix,
+                checkpoint_prefix=(
+                    "midlayer" if use_midlayer_alias else f"layers.{layer_index}"
+                ),
+                hidden_norm_key=_resolve_optional_key(
+                    model_keys,
+                    f"{layer_prefix}.hidden_norm.weight",
+                    f"{layer_prefix}.hnorm.weight",
+                    f"{layer_prefix}.pre_fc_norm_hidden.weight",
+                    global_hidden_norm_key if layer_index == 0 else None,
+                ),
+                input_layernorm_key=_resolve_optional_key(
+                    model_keys,
+                    f"{layer_prefix}.input_layernorm.weight",
+                    f"{layer_prefix}.self_attention.linear_qkv.layer_norm_weight",
+                ),
+                post_attention_layernorm_key=_resolve_optional_key(
+                    model_keys,
+                    f"{layer_prefix}.pre_mlp_layernorm.weight",
+                    f"{layer_prefix}.mlp.linear_fc1.layer_norm_weight",
+                ),
+            )
+            for layer_index, layer_prefix in sorted(layer_prefixes.items())
+        )
+
+        return cls(
+            layers=layers,
+            final_norm_key=final_norm_key,
+            lm_head_key=lm_head_key,
+        )
+
+    @property
+    def layer_by_index(self) -> dict[int, _EagleLayerLayout]:
+        return {layer.layer_index: layer for layer in self.layers}
+
+
+def _combine_or_shard_weight_parts(
+    *,
+    parameter_name: str,
+    fused_weight: Tensor | None,
+    component_weights: tuple[Tensor | None, ...],
+    target: Tensor | None,
+    tp_rank: int,
+    incomplete_error: str,
+) -> Tensor | None:
+    if fused_weight is not None:
+        return fused_weight
+
+    if not any(weight is not None for weight in component_weights):
+        return None
+    if any(weight is None for weight in component_weights):
+        raise RuntimeError(incomplete_error)
+
+    full_weight = torch.cat(
+        [weight for weight in component_weights if weight is not None],
+        dim=0,
+    ).contiguous()
+    if target is None:
+        return full_weight
+    if full_weight.shape == target.shape:
+        return full_weight.to(dtype=target.dtype)
+
+    full_dim = full_weight.shape[0]
+    local_dim = target.shape[0]
+    if local_dim <= 0 or full_dim % local_dim != 0:
+        raise RuntimeError(
+            f"[draft] Cannot infer TP sharding for '{parameter_name}': "
+            f"checkpoint={tuple(full_weight.shape)} model={tuple(target.shape)}"
+        )
+
+    inferred_tp = full_dim // local_dim
+    if tp_rank >= inferred_tp:
+        raise RuntimeError(
+            f"[draft] tp_rank={tp_rank} out of range for key '{parameter_name}' "
+            f"(inferred_tp={inferred_tp})"
+        )
+
+    # Fused Megatron weights expect each local TP shard to preserve component
+    # boundaries, e.g. [q_local, k_local, v_local] instead of chunk(full[q, k, v]).
+    local_weight_parts = []
+    for weight in component_weights:
+        assert weight is not None
+        if weight.shape[0] % inferred_tp != 0:
+            raise RuntimeError(
+                f"[draft] Cannot TP-shard fused component for '{parameter_name}': "
+                f"component={tuple(weight.shape)} inferred_tp={inferred_tp}"
+            )
+        local_weight_parts.append(
+            torch.chunk(weight, inferred_tp, dim=0)[tp_rank].contiguous()
+        )
+
+    local_weight = torch.cat(local_weight_parts, dim=0).contiguous()
+    if local_weight.shape != target.shape:
+        raise RuntimeError(
+            f"[draft] Invalid TP shard shape for '{parameter_name}': "
+            f"got={tuple(local_weight.shape)} expected={tuple(target.shape)}"
+        )
+    return local_weight.to(dtype=target.dtype)
+
+
+@dataclass
+class _PendingLayerWeights:
+    qkv_weight: Tensor | None = None
+    q_weight: Tensor | None = None
+    k_weight: Tensor | None = None
+    v_weight: Tensor | None = None
+    fc1_weight: Tensor | None = None
+    gate_weight: Tensor | None = None
+    up_weight: Tensor | None = None
+
+    def apply_to(
+        self,
+        mapped_state: StateDict,
+        layer: _EagleLayerLayout,
+        model_state: Mapping[str, Tensor],
+        tp_rank: int,
+    ) -> None:
+        qkv_weight = _combine_or_shard_weight_parts(
+            parameter_name=layer.qkv_weight_key,
+            fused_weight=self.qkv_weight,
+            component_weights=(self.q_weight, self.k_weight, self.v_weight),
+            target=model_state.get(layer.qkv_weight_key),
+            tp_rank=tp_rank,
+            incomplete_error=(
+                "[draft] Incomplete QKV tensors. Expected q_proj, k_proj, and v_proj."
+            ),
+        )
+        if qkv_weight is not None:
+            mapped_state[layer.qkv_weight_key] = qkv_weight
+
+        fc1_weight = _combine_or_shard_weight_parts(
+            parameter_name=layer.fc1_weight_key,
+            fused_weight=self.fc1_weight,
+            component_weights=(self.gate_weight, self.up_weight),
+            target=model_state.get(layer.fc1_weight_key),
+            tp_rank=tp_rank,
+            incomplete_error=(
+                "[draft] Incomplete MLP tensors. Expected gate_proj and up_proj."
+            ),
+        )
+        if fc1_weight is not None:
+            mapped_state[layer.fc1_weight_key] = fc1_weight
+
+
+def _get_num_aux_hidden_states(config: TransformerConfig) -> int:
+    aux_layer_ids = getattr(config, "eagle_aux_hidden_state_layer_ids", None)
+    if aux_layer_ids:
+        return len(aux_layer_ids)
+    if getattr(config, "use_aux_hidden_state", True):
+        return 3
+    return 0
+
+
+def _all_gather_tp_shards(local_weight: Tensor) -> list[Tensor]:
+    if (
+        not parallel_state.model_parallel_is_initialized()
+        or not dist.is_available()
+        or not dist.is_initialized()
+    ):
+        return [local_weight]
+
+    tp_group = parallel_state.get_tensor_model_parallel_group()
+    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
+    if tp_world_size == 1:
+        return [local_weight]
+
+    gathered = [torch.empty_like(local_weight) for _ in range(tp_world_size)]
+    dist.all_gather(gathered, local_weight.contiguous(), group=tp_group)
+    return gathered
+
+
+def _gather_tp_qkv_weight(
+    local_fused_weight: Tensor,
+    q_dim: int,
+    kv_dim: int,
+) -> tuple[Tensor, Tensor, Tensor]:
+    shards = _all_gather_tp_shards(local_fused_weight)
+    if len(shards) == 1 and local_fused_weight.shape[0] == q_dim + 2 * kv_dim:
+        return local_fused_weight.split([q_dim, kv_dim, kv_dim], dim=0)
+
+    tp_world_size = len(shards)
+    if q_dim % tp_world_size != 0 or kv_dim % tp_world_size != 0:
+        raise RuntimeError(
+            "QKV dimensions are not divisible by the tensor-parallel world size."
+        )
+
+    q_shards = []
+    k_shards = []
+    v_shards = []
+    local_q_dim = q_dim // tp_world_size
+    local_kv_dim = kv_dim // tp_world_size
+    for shard in shards:
+        q_local, k_local, v_local = shard.split(
+            [local_q_dim, local_kv_dim, local_kv_dim],
+            dim=0,
+        )
+        q_shards.append(q_local)
+        k_shards.append(k_local)
+        v_shards.append(v_local)
+
+    return (
+        torch.cat(q_shards, dim=0).contiguous(),
+        torch.cat(k_shards, dim=0).contiguous(),
+        torch.cat(v_shards, dim=0).contiguous(),
+    )
+
+
+def _gather_tp_gate_up_weight(
+    local_fused_weight: Tensor,
+    ffn_hidden_size: int,
+) -> tuple[Tensor, Tensor]:
+    shards = _all_gather_tp_shards(local_fused_weight)
+    if len(shards) == 1 and local_fused_weight.shape[0] == 2 * ffn_hidden_size:
+        return local_fused_weight.split([ffn_hidden_size, ffn_hidden_size], dim=0)
+
+    tp_world_size = len(shards)
+    if ffn_hidden_size % tp_world_size != 0:
+        raise RuntimeError(
+            "ffn_hidden_size is not divisible by the tensor-parallel world size."
+        )
+
+    gate_shards = []
+    up_shards = []
+    local_ffn_hidden_size = ffn_hidden_size // tp_world_size
+    for shard in shards:
+        gate_local, up_local = shard.split(
+            [local_ffn_hidden_size, local_ffn_hidden_size],
+            dim=0,
+        )
+        gate_shards.append(gate_local)
+        up_shards.append(up_local)
+
+    return (
+        torch.cat(gate_shards, dim=0).contiguous(),
+        torch.cat(up_shards, dim=0).contiguous(),
+    )
+
+
+def _gather_tp_weight_if_needed(
+    local_weight: Tensor,
+    expected_shape_or_tp_group: tuple[int, ...] | dist.ProcessGroup | None,
+    split_axis: int | None = None,
+) -> Tensor:
+    if split_axis is None:
+        tp_group = expected_shape_or_tp_group
+        if tp_group is None or not dist.is_available() or not dist.is_initialized():
+            return local_weight
+
+        tp_world_size = dist.get_world_size(tp_group)
+        if tp_world_size <= 1:
+            return local_weight
+
+        gathered = [torch.empty_like(local_weight) for _ in range(tp_world_size)]
+        dist.all_gather(gathered, local_weight.contiguous(), group=tp_group)
+        return torch.cat(gathered, dim=0).contiguous()
+
+    expected_shape = expected_shape_or_tp_group
+    if not isinstance(expected_shape, tuple):
+        raise TypeError(
+            "expected_shape_or_tp_group must be a shape tuple when split_axis is set."
+        )
+    if tuple(local_weight.shape) == expected_shape:
+        return local_weight
+
+    shards = _all_gather_tp_shards(local_weight)
+    if len(shards) == 1:
+        return local_weight
+    return torch.cat(shards, dim=split_axis).contiguous()
+
+
+def _extract_tensor_state_dict(
+    checkpoint_obj: object,
+    checkpoint_path: Path,
+) -> StateDict:
+    if (
+        isinstance(checkpoint_obj, dict)
+        and "state_dict" in checkpoint_obj
+        and isinstance(checkpoint_obj["state_dict"], dict)
+    ):
+        checkpoint_obj = checkpoint_obj["state_dict"]
+
+    if not isinstance(checkpoint_obj, dict):
+        raise RuntimeError(
+            f"[draft] Unsupported checkpoint payload in '{checkpoint_path}'. "
+            "Expected a state dict or a dict containing `state_dict`."
+        )
+
+    state_dict = {
+        key: value
+        for key, value in checkpoint_obj.items()
+        if isinstance(key, str) and isinstance(value, Tensor)
+    }
+    if not state_dict:
+        raise RuntimeError(
+            f"[draft] Checkpoint '{checkpoint_path}' did not contain any tensors."
+        )
+    return state_dict
+
+
+def _load_safetensors_file(checkpoint_path: Path) -> StateDict:
+    from safetensors.torch import load_file as load_safetensors
+
+    return _extract_tensor_state_dict(
+        load_safetensors(str(checkpoint_path)),
+        checkpoint_path,
+    )
+
+
+def _load_torch_file(checkpoint_path: Path) -> StateDict:
+    try:
+        checkpoint_obj = torch.load(
+            str(checkpoint_path),
+            map_location="cpu",
+            weights_only=True,
+        )
+    except TypeError:
+        checkpoint_obj = torch.load(
+            str(checkpoint_path),
+            map_location="cpu",
+        )
+
+    return _extract_tensor_state_dict(checkpoint_obj, checkpoint_path)
+
+
+def _merge_checkpoint_shards(
+    checkpoint_dir: Path,
+    shard_names: list[str],
+    shard_loader: CheckpointLoader,
+    source_name: str,
+) -> StateDict:
+    merged_state: StateDict = {}
+
+    for shard_name in shard_names:
+        shard_path = checkpoint_dir / shard_name
+        if not shard_path.exists():
+            raise FileNotFoundError(
+                f"[draft] Missing shard '{shard_name}' referenced by '{source_name}'."
+            )
+
+        shard_state = shard_loader(shard_path)
+        duplicate_keys = set(merged_state).intersection(shard_state)
+        if duplicate_keys:
+            duplicate_preview = ", ".join(sorted(duplicate_keys)[:5])
+            raise RuntimeError(
+                f"[draft] Duplicate keys found while merging '{source_name}': "
+                f"{duplicate_preview}"
+            )
+        merged_state.update(shard_state)
+
+    return merged_state
+
+
+def _load_index_checkpoint(index_path: Path) -> StateDict:
+    with index_path.open() as handle:
+        try:
+            index_data = json.load(handle)
+        except json.JSONDecodeError as exc:
+            raise RuntimeError(
+                f"[draft] Failed to parse checkpoint index '{index_path}'."
+            ) from exc
+
+    weight_map = index_data.get("weight_map")
+    if not isinstance(weight_map, dict) or not weight_map:
+        raise RuntimeError(
+            f"[draft] Checkpoint index '{index_path}' does not contain a valid "
+            "`weight_map`."
+        )
+
+    shard_names = sorted(
+        {
+            shard_name
+            for shard_name in weight_map.values()
+            if isinstance(shard_name, str)
+        }
+    )
+    if not shard_names:
+        raise RuntimeError(
+            f"[draft] Checkpoint index '{index_path}' does not reference any "
+            "weight shards."
+        )
+
+    if index_path.name == "model.safetensors.index.json":
+        return _merge_checkpoint_shards(
+            index_path.parent,
+            shard_names,
+            _load_safetensors_file,
+            index_path.name,
+        )
+    if index_path.name == "pytorch_model.bin.index.json":
+        return _merge_checkpoint_shards(
+            index_path.parent,
+            shard_names,
+            _load_torch_file,
+            index_path.name,
+        )
+
+    raise RuntimeError(
+        f"[draft] Unsupported checkpoint index format '{index_path.name}'."
+    )
+
+
+def _load_checkpoint_file(checkpoint_path: Path) -> StateDict:
+    if (
+        checkpoint_path.name.startswith("model-")
+        and checkpoint_path.suffix == ".safetensors"
+    ):
+        companion_index = checkpoint_path.parent / "model.safetensors.index.json"
+        if companion_index.exists():
+            return _load_index_checkpoint(companion_index)
+
+        sibling_shards = sorted(
+            shard_path.name
+            for shard_path in checkpoint_path.parent.glob("model-*.safetensors")
+        )
+        if len(sibling_shards) > 1:
+            return _merge_checkpoint_shards(
+                checkpoint_path.parent,
+                sibling_shards,
+                _load_safetensors_file,
+                str(checkpoint_path.parent),
+            )
+
+    if (
+        checkpoint_path.name.startswith("pytorch_model-")
+        and checkpoint_path.suffix == ".bin"
+    ):
+        companion_index = checkpoint_path.parent / "pytorch_model.bin.index.json"
+        if companion_index.exists():
+            return _load_index_checkpoint(companion_index)
+
+        sibling_shards = sorted(
+            shard_path.name
+            for shard_path in checkpoint_path.parent.glob("pytorch_model-*.bin")
+        )
+        if len(sibling_shards) > 1:
+            return _merge_checkpoint_shards(
+                checkpoint_path.parent,
+                sibling_shards,
+                _load_torch_file,
+                str(checkpoint_path.parent),
+            )
+
+    if checkpoint_path.suffix == ".safetensors":
+        return _load_safetensors_file(checkpoint_path)
+    if checkpoint_path.suffix == ".bin":
+        return _load_torch_file(checkpoint_path)
+    if checkpoint_path.name.endswith(".index.json"):
+        return _load_index_checkpoint(checkpoint_path)
+
+    raise RuntimeError(
+        f"[draft] Unsupported checkpoint file '{checkpoint_path}'. Expected "
+        "a `.safetensors`, `.bin`, or `.index.json` file."
+    )
+
+
+def _load_checkpoint_from_directory(checkpoint_dir: Path) -> StateDict:
+    for candidate_name in _CHECKPOINT_CANDIDATE_NAMES:
+        candidate_path = checkpoint_dir / candidate_name
+        if candidate_path.exists():
+            return _load_checkpoint_file(candidate_path)
+
+    safetensor_shards = sorted(
+        shard_path.name for shard_path in checkpoint_dir.glob("model-*.safetensors")
+    )
+    if safetensor_shards:
+        return _merge_checkpoint_shards(
+            checkpoint_dir,
+            safetensor_shards,
+            _load_safetensors_file,
+            str(checkpoint_dir),
+        )
+
+    torch_shards = sorted(
+        shard_path.name for shard_path in checkpoint_dir.glob("pytorch_model-*.bin")
+    )
+    if torch_shards:
+        return _merge_checkpoint_shards(
+            checkpoint_dir,
+            torch_shards,
+            _load_torch_file,
+            str(checkpoint_dir),
+        )
+
+    raise FileNotFoundError(
+        f"[draft] No supported checkpoint files were found in '{checkpoint_dir}'."
+    )
+
+
+def _load_checkpoint_state(checkpoint_source: str) -> StateDict:
+    source_path = Path(checkpoint_source)
+    if source_path.is_file():
+        return _load_checkpoint_file(source_path)
+    if source_path.is_dir():
+        return _load_checkpoint_from_directory(source_path)
+
+    try:
+        from huggingface_hub import snapshot_download
+
+        source_path = Path(
+            snapshot_download(
+                repo_id=checkpoint_source,
+                allow_patterns=_HF_SNAPSHOT_ALLOW_PATTERNS,
+                ignore_patterns=_HF_SNAPSHOT_IGNORE_PATTERNS,
+            )
+        )
+    except Exception as exc:
+        raise FileNotFoundError(
+            f"[draft] Could not resolve '{checkpoint_source}' as a local checkpoint "
+            "path or Hugging Face repo."
+        ) from exc
+
+    return _load_checkpoint_from_directory(source_path)
+
+
+def _normalize_hf_key(raw_hf_key: str) -> str:
+    hf_key = raw_hf_key
+    prefixes = ("draft.", "module.", "eagle_module.")
+    changed = True
+    while changed:
+        changed = False
+        for prefix in prefixes:
+            if hf_key.startswith(prefix):
+                hf_key = hf_key.removeprefix(prefix)
+                changed = True
+    return hf_key
+
+
+def _parse_layer_checkpoint_key(hf_key: str) -> tuple[int, str] | None:
+    if hf_key.startswith("midlayer."):
+        return 0, hf_key.removeprefix("midlayer.")
+
+    match = _CHECKPOINT_LAYER_KEY_PATTERN.match(hf_key)
+    if match is None:
+        return None
+
+    return int(match.group(1)), match.group(2)
+
+
+def _get_tp_rank() -> int:
+    if parallel_state.model_parallel_is_initialized():
+        return parallel_state.get_tensor_model_parallel_rank()
+    return 0
+
+
+def _build_split_axis_by_parameter(layout: _EagleModelLayout) -> dict[str, int]:
+    split_axis_by_parameter = {
+        "eagle_module.fc.weight": 0,
+    }
+    if layout.lm_head_key is not None:
+        split_axis_by_parameter[layout.lm_head_key] = 0
+    for layer in layout.layers:
+        split_axis_by_parameter[layer.qkv_weight_key] = 0
+        split_axis_by_parameter[layer.proj_weight_key] = 1
+        split_axis_by_parameter[layer.fc1_weight_key] = 0
+        split_axis_by_parameter[layer.fc2_weight_key] = 1
+    return split_axis_by_parameter
+
+
+def _shard_to_local_tp(
+    parameter_name: str,
+    tensor: Tensor,
+    model_state: Mapping[str, Tensor],
+    split_axis_by_parameter: Mapping[str, int],
+    tp_rank: int,
+) -> Tensor:
+    target = model_state.get(parameter_name)
+    if target is None:
+        return tensor
+
+    if tensor.shape == target.shape:
+        return tensor.to(dtype=target.dtype)
+
+    split_axis = split_axis_by_parameter.get(parameter_name)
+    if split_axis is None:
+        raise RuntimeError(
+            f"[draft] Unexpected shape mismatch for non-TP key '{parameter_name}': "
+            f"checkpoint={tuple(tensor.shape)} model={tuple(target.shape)}"
+        )
+
+    full_dim = tensor.shape[split_axis]
+    local_dim = target.shape[split_axis]
+    if local_dim <= 0 or full_dim % local_dim != 0:
+        raise RuntimeError(
+            f"[draft] Cannot infer TP sharding for '{parameter_name}': "
+            f"checkpoint={tuple(tensor.shape)} model={tuple(target.shape)}"
+        )
+
+    inferred_tp = full_dim // local_dim
+    if tp_rank >= inferred_tp:
+        raise RuntimeError(
+            f"[draft] tp_rank={tp_rank} out of range for key '{parameter_name}' "
+            f"(inferred_tp={inferred_tp})"
+        )
+
+    local_shard = torch.chunk(tensor, inferred_tp, dim=split_axis)[tp_rank]
+    local_shard = local_shard.contiguous()
+    if local_shard.shape != target.shape:
+        raise RuntimeError(
+            f"[draft] Invalid TP shard shape for '{parameter_name}': "
+            f"got={tuple(local_shard.shape)} expected={tuple(target.shape)}"
+        )
+    return local_shard.to(dtype=target.dtype)
+
+
+def _assign_optional_layer_weight(
+    *,
+    model_key: str | None,
+    hf_weight: Tensor,
+    mapped_state: StateDict,
+) -> bool:
+    if model_key is None:
+        return False
+    mapped_state[model_key] = hf_weight
+    return True
+
+
+def _map_layer_hf_weight(
+    layer_key: str,
+    hf_weight: Tensor,
+    layer: _EagleLayerLayout,
+    mapped_state: StateDict,
+    pending_weights: _PendingLayerWeights,
+) -> None:
+    checkpoint_key = f"{layer.checkpoint_prefix}.{layer_key}"
+
+    if layer_key == "self_attn.qkv_proj.weight":
+        pending_weights.qkv_weight = hf_weight
+    elif layer_key == "self_attn.q_proj.weight":
+        pending_weights.q_weight = hf_weight
+    elif layer_key == "self_attn.k_proj.weight":
+        pending_weights.k_weight = hf_weight
+    elif layer_key == "self_attn.v_proj.weight":
+        pending_weights.v_weight = hf_weight
+    elif layer_key == "self_attn.o_proj.weight":
+        mapped_state[layer.proj_weight_key] = hf_weight
+    elif layer_key == "mlp.gate_up_proj.weight":
+        pending_weights.fc1_weight = hf_weight
+    elif layer_key == "mlp.gate_proj.weight":
+        pending_weights.gate_weight = hf_weight
+    elif layer_key == "mlp.up_proj.weight":
+        pending_weights.up_weight = hf_weight
+    elif layer_key == "mlp.down_proj.weight":
+        mapped_state[layer.fc2_weight_key] = hf_weight
+    elif layer_key == "hidden_norm.weight":
+        _assign_optional_layer_weight(
+            model_key=layer.hidden_norm_key,
+            hf_weight=hf_weight,
+            mapped_state=mapped_state,
+        )
+    elif layer_key == "input_layernorm.weight":
+        _assign_optional_layer_weight(
+            model_key=layer.input_layernorm_key,
+            hf_weight=hf_weight,
+            mapped_state=mapped_state,
+        )
+    elif layer_key == "post_attention_layernorm.weight":
+        _assign_optional_layer_weight(
+            model_key=layer.post_attention_layernorm_key,
+            hf_weight=hf_weight,
+            mapped_state=mapped_state,
+        )
+    else:
+        raise RuntimeError(
+            f"[draft] Unsupported Eagle checkpoint key '{checkpoint_key}'."
+        )
+
+
+def _map_hf_state_to_eagle_state(
+    hf_state_dict: Mapping[str, Tensor],
+    model_state: Mapping[str, Tensor],
+    layout: _EagleModelLayout,
+    checkpoint_source: str,
+) -> StateDict:
+    mapped_state: StateDict = {}
+    pending_weights_by_layer = {
+        layer.layer_index: _PendingLayerWeights() for layer in layout.layers
+    }
+    layers_by_index = layout.layer_by_index
+
+    for raw_hf_key, hf_weight in hf_state_dict.items():
+        hf_key = _normalize_hf_key(raw_hf_key)
+
+        if hf_key == "fc.weight":
+            mapped_state["eagle_module.fc.weight"] = hf_weight
+            continue
+        if hf_key == "norm.weight":
+            if layout.final_norm_key is None:
+                raise RuntimeError(
+                    "[draft] Checkpoint contains 'norm.weight' but the Eagle model "
+                    "does not expose a matching final norm."
+                )
+            mapped_state[layout.final_norm_key] = hf_weight
+            continue
+        if hf_key in {"lm_head.weight", "eagle_output_layer.weight"}:
+            if layout.lm_head_key is None:
+                raise RuntimeError(
+                    "[draft] Checkpoint contains draft LM-head weights but the "
+                    "Eagle model does not expose a matching output layer."
+                )
+            mapped_state[layout.lm_head_key] = hf_weight
+            continue
+
+        parsed_layer_key = _parse_layer_checkpoint_key(hf_key)
+        if parsed_layer_key is None:
+            continue
+
+        layer_index, layer_key = parsed_layer_key
+        layer = layers_by_index.get(layer_index)
+        if layer is None:
+            raise RuntimeError(
+                f"[draft] Checkpoint '{checkpoint_source}' contains weights for "
+                f"layer {layer_index}, but the Eagle model only exposes layers "
+                f"{sorted(layers_by_index)}."
+            )
+
+        _map_layer_hf_weight(
+            layer_key=layer_key,
+            hf_weight=hf_weight,
+            layer=layer,
+            mapped_state=mapped_state,
+            pending_weights=pending_weights_by_layer[layer_index],
+        )
+
+    tp_rank = _get_tp_rank()
+    for layer in layout.layers:
+        pending_weights_by_layer[layer.layer_index].apply_to(
+            mapped_state,
+            layer,
+            model_state=model_state,
+            tp_rank=tp_rank,
+        )
+
+    if not mapped_state:
+        raise RuntimeError(
+            f"[draft] No Eagle weights were mapped from checkpoint "
+            f"'{checkpoint_source}'."
+        )
+
+    split_axis_by_parameter = _build_split_axis_by_parameter(layout)
+    for parameter_name in list(mapped_state):
+        mapped_state[parameter_name] = _shard_to_local_tp(
+            parameter_name=parameter_name,
+            tensor=mapped_state[parameter_name],
+            model_state=model_state,
+            split_axis_by_parameter=split_axis_by_parameter,
+            tp_rank=tp_rank,
+        )
+
+    return mapped_state
+
+
+def load_hf_weights_to_eagle(
+    model: torch.nn.Module,
+    model_name: str,
+) -> tuple[list[str], list[str]]:
+    """Load HF Eagle weights from a local path or Hub repo into a draft model."""
+    if not model_name or not model_name.strip():
+        raise ValueError(
+            "load_hf_weights_to_eagle requires a non-empty model name or path."
+ ) + + hf_state_dict = _load_checkpoint_state(model_name) + model_state = model.state_dict() + layout = _EagleModelLayout.detect(model_state) + new_state = _map_hf_state_to_eagle_state( + hf_state_dict=hf_state_dict, + model_state=model_state, + layout=layout, + checkpoint_source=model_name, + ) + + return model.load_state_dict(new_state, strict=False) + + +def _require_state_tensor( + source_state: Mapping[str, Tensor], + parameter_name: str, +) -> Tensor: + if parameter_name not in source_state: + raise RuntimeError( + f"[draft] Missing required Eagle parameter '{parameter_name}' while " + "exporting weights." + ) + return source_state[parameter_name] + + +def find_draft_owner_chunk(model: list[MegatronModule]) -> MegatronModule | None: + """Return the post-process chunk that should own the nested draft model.""" + for model_chunk in reversed(model): + if getattr(model_chunk, "post_process", False): + return model_chunk + language_model = getattr(model_chunk, "language_model", None) + if language_model is not None and getattr( + language_model, "post_process", False + ): + return model_chunk + return None + + +def get_attached_draft_model(model: list[MegatronModule]) -> MegatronModule | None: + """Find an already attached draft model after Megatron wrapping has been applied.""" + for model_chunk in reversed(model): + unwrapped_chunk = unwrap_model(model_chunk) + draft_model = getattr(unwrapped_chunk, "draft_model", None) + if draft_model is not None: + return draft_model + return None + + +def _export_layer_weights_to_hf( + *, + source_state: Mapping[str, Tensor], + layer: _EagleLayerLayout, + q_dim: int, + kv_dim: int, + hidden_size: int, + ffn_hidden_size: int, +) -> list[tuple[str, Tensor]]: + layer_prefix = layer.checkpoint_prefix + hf_state: list[tuple[str, Tensor]] = [] + + if layer.hidden_norm_key is not None: + hf_state.append( + ( + f"{layer_prefix}.hidden_norm.weight", + _require_state_tensor(source_state, layer.hidden_norm_key), + ) + ) + + if 
layer.input_layernorm_key is not None: + hf_state.append( + ( + f"{layer_prefix}.input_layernorm.weight", + _require_state_tensor(source_state, layer.input_layernorm_key), + ) + ) + + q_proj, k_proj, v_proj = _gather_tp_qkv_weight( + _require_state_tensor(source_state, layer.qkv_weight_key), + q_dim=q_dim, + kv_dim=kv_dim, + ) + hf_state.append((f"{layer_prefix}.self_attn.q_proj.weight", q_proj)) + hf_state.append((f"{layer_prefix}.self_attn.k_proj.weight", k_proj)) + hf_state.append((f"{layer_prefix}.self_attn.v_proj.weight", v_proj)) + + o_proj = _gather_tp_weight_if_needed( + _require_state_tensor(source_state, layer.proj_weight_key), + (hidden_size, hidden_size), + split_axis=1, + ) + hf_state.append((f"{layer_prefix}.self_attn.o_proj.weight", o_proj)) + + if layer.post_attention_layernorm_key is not None: + hf_state.append( + ( + f"{layer_prefix}.post_attention_layernorm.weight", + _require_state_tensor(source_state, layer.post_attention_layernorm_key), + ) + ) + + gate_proj, up_proj = _gather_tp_gate_up_weight( + _require_state_tensor(source_state, layer.fc1_weight_key), + ffn_hidden_size=ffn_hidden_size, + ) + hf_state.append((f"{layer_prefix}.mlp.gate_proj.weight", gate_proj)) + hf_state.append((f"{layer_prefix}.mlp.up_proj.weight", up_proj)) + + down_proj = _gather_tp_weight_if_needed( + _require_state_tensor(source_state, layer.fc2_weight_key), + (hidden_size, ffn_hidden_size), + split_axis=1, + ) + hf_state.append((f"{layer_prefix}.mlp.down_proj.weight", down_proj)) + + return hf_state + + +def export_eagle_weights_to_hf( + model: torch.nn.Module, +) -> list[tuple[str, Tensor]]: + """Export the standalone Eagle draft model to HF naming.""" + unwrapped_model = unwrap_model(model) + source_state = unwrapped_model.state_dict() + config = unwrapped_model.config + layout = _EagleModelLayout.detect(source_state) + + q_dim = config.num_attention_heads * config.kv_channels + kv_dim = config.num_query_groups * config.kv_channels + ffn_hidden_size = 
config.ffn_hidden_size + num_aux_hidden_states = _get_num_aux_hidden_states(config) + + fc_weight = _gather_tp_weight_if_needed( + _require_state_tensor(source_state, "eagle_module.fc.weight"), + ( + config.hidden_size, + config.hidden_size * num_aux_hidden_states, + ), + split_axis=0, + ) + hf_state: list[tuple[str, Tensor]] = [("fc.weight", fc_weight)] + + for layer in layout.layers: + hf_state.extend( + _export_layer_weights_to_hf( + source_state=source_state, + layer=layer, + q_dim=q_dim, + kv_dim=kv_dim, + hidden_size=config.hidden_size, + ffn_hidden_size=ffn_hidden_size, + ) + ) + + if layout.final_norm_key is not None: + hf_state.append( + ( + "norm.weight", + _require_state_tensor(source_state, layout.final_norm_key), + ) + ) + if layout.lm_head_key is not None: + hf_state.append( + ( + "lm_head.weight", + _gather_tp_weight_if_needed( + _require_state_tensor(source_state, layout.lm_head_key), + (config.draft_vocab_size, config.hidden_size), + split_axis=0, + ), + ) + ) + + return hf_state + + +def get_policy_lm_head_weight(policy_model_chunk: MegatronModule) -> torch.Tensor: + """Return the local policy LM-head shard for draft initialization.""" + unwrapped_policy_model = unwrap_model(policy_model_chunk) + if getattr(unwrapped_policy_model, "share_embeddings_and_output_weights", False): + return unwrapped_policy_model.shared_embedding_or_output_weight() + return unwrapped_policy_model.output_layer.weight + + +def _get_draft_output_layer(draft_model: MegatronModule): + draft_output_layer = getattr( + getattr(draft_model, "eagle_module", None), "eagle_output_layer", None + ) + if draft_output_layer is None: + raise RuntimeError( + "[draft] Draft model was configured with has_lm_head=True but does not " + "expose eagle_output_layer." 
+ ) + return draft_output_layer + + +def _get_draft_to_target_token_mapping( + draft_model: MegatronModule, + device: torch.device, +) -> torch.Tensor: + draft_vocab_size = int(draft_model.config.draft_vocab_size) + reverse_mapping = torch.arange(draft_vocab_size, device=device, dtype=torch.long) + d2t = getattr(draft_model.eagle_module, "d2t", None) + if d2t is not None: + reverse_mapping = reverse_mapping + d2t.to(device=device, dtype=torch.long) + return reverse_mapping + + +def copy_policy_lm_head_to_draft( + *, + draft_model: MegatronModule, + policy_model_chunk: MegatronModule, +) -> None: + """Initialize the draft LM head from the policy LM head shard.""" + draft_output_layer = _get_draft_output_layer(draft_model) + tp_group = getattr(draft_output_layer, "tp_group", None) or getattr( + draft_output_layer, "_tp_group", None + ) + policy_lm_head_weight = get_policy_lm_head_weight(policy_model_chunk).detach() + policy_lm_head_weight = _gather_tp_weight_if_needed(policy_lm_head_weight, tp_group) + draft_token_mapping = _get_draft_to_target_token_mapping( + draft_model, + device=policy_lm_head_weight.device, + ) + if draft_token_mapping.numel() == 0: + raise RuntimeError("[draft] Draft token mapping is empty.") + if int(draft_token_mapping.max().item()) >= policy_lm_head_weight.shape[0]: + raise RuntimeError( + "[draft] Cannot initialize draft LM head from policy LM head because " + f"the draft token mapping references policy vocab index {int(draft_token_mapping.max().item())}, " + f"but the gathered policy LM head only has {policy_lm_head_weight.shape[0]} rows." 
+ ) + + selected_policy_weight = policy_lm_head_weight.index_select(0, draft_token_mapping) + if tp_group is not None and dist.is_initialized(): + tp_world_size = dist.get_world_size(tp_group) + if tp_world_size > 1: + if selected_policy_weight.shape[0] % tp_world_size != 0: + raise RuntimeError( + "[draft] Cannot shard selected policy LM head rows across TP " + f"world size {tp_world_size}: rows={selected_policy_weight.shape[0]}." + ) + tp_rank = dist.get_rank(tp_group) + selected_policy_weight = torch.chunk( + selected_policy_weight, + tp_world_size, + dim=0, + )[tp_rank].contiguous() + + if draft_output_layer.weight.shape != selected_policy_weight.shape: + raise RuntimeError( + "[draft] Cannot initialize draft LM head from policy LM head because " + f"their local shard shapes differ after draft-vocab selection: " + f"draft={tuple(draft_output_layer.weight.shape)} " + f"policy_selected={tuple(selected_policy_weight.shape)}." + ) + + with torch.no_grad(): + draft_output_layer.weight.copy_( + selected_policy_weight.to( + device=draft_output_layer.weight.device, + dtype=draft_output_layer.weight.dtype, + ) + ) + + +def build_draft_model( + model_provider, + draft_config: dict[str, Any], + pg_collection: ProcessGroupCollection, + policy_model_chunk: MegatronModule, +) -> MegatronModule | None: + """Build an Eagle draft model before parent mixed-precision/DDP wrapping.""" + if not draft_config["enabled"]: + return None + + from transformers import AutoConfig + + from nemo_rl.models.megatron.draft.eagle import EagleModel + from nemo_rl.models.megatron.draft.hidden_capture import ( + get_eagle3_aux_hidden_state_layers, + ) + + model_name = draft_config.get("model_name") + hf_config = AutoConfig.from_pretrained(model_name).to_dict() if model_name else {} + draft_num_layers = draft_config.get("num_layers") + config = TransformerConfig( + normalization="RMSNorm", + activation_func=torch.nn.functional.silu, + gated_linear_unit=True, + hidden_dropout=0.0, + 
attention_softmax_in_fp32=False, + tensor_model_parallel_size=model_provider.tensor_model_parallel_size, + pipeline_model_parallel_size=model_provider.pipeline_model_parallel_size, + expert_tensor_parallel_size=model_provider.expert_tensor_parallel_size, + sequence_parallel=model_provider.sequence_parallel, + use_cpu_initialization=model_provider.use_cpu_initialization, + fp16=model_provider.fp16, + bf16=model_provider.bf16, + params_dtype=model_provider.params_dtype, + pipeline_dtype=model_provider.pipeline_dtype, + num_layers=( + hf_config.get("num_hidden_layers", 1) + if model_name is not None + else draft_num_layers or 1 + ), + ffn_hidden_size=hf_config.get( + "intermediate_size", model_provider.ffn_hidden_size + ), + num_attention_heads=hf_config.get( + "num_attention_heads", model_provider.num_attention_heads + ), + kv_channels=hf_config.get("head_dim", model_provider.kv_channels), + num_query_groups=hf_config.get( + "num_key_value_heads", model_provider.num_query_groups + ), + init_method_std=model_provider.init_method_std, + layernorm_epsilon=hf_config.get( + "rms_norm_eps", model_provider.layernorm_epsilon + ), + add_bias_linear=hf_config.get("mlp_bias", model_provider.add_bias_linear), + attention_dropout=hf_config.get( + "attention_dropout", model_provider.attention_dropout + ), + ) + + config.transformer_layer_spec = None + config.hidden_size = hf_config.get("hidden_size", model_provider.hidden_size) + config.vocab_size = hf_config.get("vocab_size", model_provider.vocab_size) + config.draft_vocab_size = hf_config.get("draft_vocab_size", config.vocab_size) + config.seq_length = model_provider.seq_length + config.gradient_accumulation_fusion = False + config.position_embedding_type = hf_config.get( + "position_embedding_type", model_provider.position_embedding_type + ) + config.rotary_percent = model_provider.rotary_percent + config.rotary_base = hf_config.get("rope_theta", model_provider.rotary_base) + config.rope_scaling = ( + "rope_scaling" in 
hf_config if hf_config else model_provider.rope_scaling
+    )
+    config.rope_scaling_factor = (
+        # Guard against configs where "rope_scaling" is present but null.
+        (hf_config.get("rope_scaling") or {}).get("factor")
+        if hf_config
+        else model_provider.rope_scaling_factor
+    )
+
+    config.use_input_layernorm_in_first_layer = hf_config.get(
+        "use_input_layernorm_in_first_layer", True
+    )
+    config.use_last_layernorm = hf_config.get("use_last_layernorm", True)
+    config.use_aux_hidden_state = hf_config.get("use_aux_hidden_state", True)
+    if model_name is not None:
+        config.eagle_aux_hidden_state_layer_ids = hf_config.get(
+            "eagle_aux_hidden_state_layer_ids", []
+        )
+    else:
+        config.eagle_aux_hidden_state_layer_ids = (
+            draft_config.get("aux_layer_indices") or []
+        )
+    if (
+        config.use_aux_hidden_state
+        and len(config.eagle_aux_hidden_state_layer_ids) == 0
+    ):
+        config.eagle_aux_hidden_state_layer_ids = get_eagle3_aux_hidden_state_layers(
+            model_provider.num_layers
+        )
+
+    config.parallel_draft_step = 1
+    config.use_mtp_layernorm = config.parallel_draft_heads_num_layers = None
+    config.has_lm_head = True
+
+    draft_model = EagleModel(config=config)
+    tp_group = getattr(pg_collection, "tp", None)
+    if tp_group is not None:
+        for module in draft_model.modules():
+            if hasattr(module, "pg_collection"):
+                module.pg_collection = pg_collection
+            if hasattr(module, "_pg_collection"):
+                module._pg_collection = pg_collection
+            if hasattr(module, "tp_group"):
+                module.tp_group = tp_group
+            if hasattr(module, "_tp_group"):
+                module._tp_group = tp_group
+
+    if model_name is not None:
+        missing_keys, unexpected_keys = load_hf_weights_to_eagle(
+            draft_model, model_name
+        )
+        draft_lm_head_key = "eagle_module.eagle_output_layer.weight"
+        if draft_lm_head_key in missing_keys:
+            copy_policy_lm_head_to_draft(
+                draft_model=draft_model,
+                policy_model_chunk=policy_model_chunk,
+            )
+            missing_keys = [key for key in missing_keys if key != draft_lm_head_key]
+            print(
+                "[draft] Draft checkpoint did not contain lm_head.weight; "
+                "initialized draft LM head 
from the policy output layer." + ) + if missing_keys: + print(f"[draft] Missing keys after draft load: {missing_keys}") + if unexpected_keys: + print(f"[draft] Unexpected keys after draft load: {unexpected_keys}") + else: + copy_policy_lm_head_to_draft( + draft_model=draft_model, + policy_model_chunk=policy_model_chunk, + ) + print("[draft] Initialized draft LM head from the policy output layer.") + + return draft_model diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 6a58d930d5..8376fc565b 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -15,7 +15,7 @@ import os import time import warnings -from typing import Any, Optional, TypeVar +from typing import Any, Callable, Optional, TypeVar import torch from megatron.bridge import AutoBridge @@ -23,6 +23,7 @@ from megatron.bridge.peft.lora import LoRA from megatron.bridge.training import fault_tolerance from megatron.bridge.training.checkpointing import ( + _load_checkpoint_from_path, checkpoint_exists, init_checkpointing_context, load_checkpoint, @@ -48,6 +49,7 @@ ) from megatron.bridge.training.state import GlobalState from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer +from megatron.bridge.training.utils.pg_utils import get_pg_collection from megatron.bridge.utils.instantiate_utils import InstantiationMode from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size from megatron.core import parallel_state @@ -73,6 +75,11 @@ from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.models.megatron.community_import import import_model_from_hf_name from nemo_rl.models.megatron.config import ModelAndOptimizerState, RuntimeConfig +from nemo_rl.models.megatron.draft.utils import ( + build_draft_model, + find_draft_owner_chunk, + get_attached_draft_model, +) from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.utils import ( configure_dynamo_cache, @@ -647,6 +654,67 @@ def 
_create_megatron_config( ) +def _create_draft_pre_wrap_hook( + policy_cfg: PolicyConfig, + megatron_cfg: ConfigContainer, + state: GlobalState, + *, + preload_policy_from_pretrained: bool, +) -> Callable[[list[MegatronModule]], list[MegatronModule]]: + """Create the hook that attaches draft weights before mixed-precision/DDP wrapping.""" + draft_cfg = policy_cfg["draft"] + + def draft_pre_wrap_hook(model: list[MegatronModule]) -> list[MegatronModule]: + """Optionally preload the base policy, then attach the draft module to the owner chunk.""" + if not draft_cfg["enabled"]: + return model + + # Base pretrained checkpoints do not contain draft weights, so load the + # policy weights before attaching the nested draft module. + if preload_policy_from_pretrained: + pretrained_checkpoint = megatron_cfg.checkpoint.pretrained_checkpoint + if pretrained_checkpoint is None or not checkpoint_exists( + pretrained_checkpoint + ): + raise ValueError( + f"Invalid pretrained checkpoint directory found: {pretrained_checkpoint}" + ) + megatron_cfg.checkpoint.finetune = True + _load_checkpoint_from_path( + load_dir=pretrained_checkpoint, + state=state, + model=model, + optimizer=None, + opt_param_scheduler=None, + checkpointing_context={}, + skip_load_to_model_and_opt=False, + ignore_ckpt_step=True, + ) + + draft_owner = find_draft_owner_chunk(model) + if draft_owner is None: + return model + + if getattr(draft_owner, "draft_model", None) is not None: + raise RuntimeError( + "Policy model chunk already has an attached `draft_model`." 
+ ) + + pg_collection = get_pg_collection(model) + draft_model = build_draft_model( + megatron_cfg.model, + draft_config=draft_cfg, + pg_collection=pg_collection, + policy_model_chunk=draft_owner, + ) + if draft_model is not None: + setattr(draft_owner, "draft_model", draft_model) + + return model + + return draft_pre_wrap_hook + + def setup_model_and_optimizer( policy_cfg: PolicyConfig, megatron_cfg: ConfigContainer, @@ -709,6 +777,22 @@ def setup_model_and_optimizer( pre_wrap_hook = [] use_peft = policy_cfg["megatron_cfg"].get("peft", {}).get("enabled", False) + draft_cfg = policy_cfg["draft"] + draft_enabled = draft_cfg["enabled"] + resume_checkpoint_exists = ( + megatron_cfg.checkpoint.load is not None + and checkpoint_exists(megatron_cfg.checkpoint.load) + ) + pretrained_checkpoint_exists = ( + megatron_cfg.checkpoint.pretrained_checkpoint is not None + and checkpoint_exists(megatron_cfg.checkpoint.pretrained_checkpoint) + ) + preload_policy_from_pretrained_for_draft = ( + draft_enabled + and not use_peft # The PEFT pre-wrap hook loads the pretrained base policy before adapters are attached. + and not resume_checkpoint_exists # Resume checkpoints already carry the attached draft module state. + and pretrained_checkpoint_exists + ) mixed_precision_wrapper = Float16Module if policy_cfg["megatron_cfg"]["freeze_moe_router"]: @@ -767,6 +851,15 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: pre_wrap_hook.extend([composed_peft_hook]) + if draft_enabled: + draft_pre_wrap_hook = _create_draft_pre_wrap_hook( + policy_cfg, + megatron_cfg, + state, + preload_policy_from_pretrained=preload_policy_from_pretrained_for_draft, + ) + pre_wrap_hook.extend([draft_pre_wrap_hook]) + # Model, optimizer, and learning rate. 
pg_collection = ProcessGroupCollection.use_mpu_process_groups() setattr(megatron_cfg.model, "_pg_collection", pg_collection) @@ -784,6 +877,7 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: mixed_precision_wrapper=mixed_precision_wrapper, pg_collection=pg_collection, ) + if load_optimizer: optimizer, scheduler = setup_optimizer( optimizer_config=megatron_cfg.optimizer, @@ -799,21 +893,15 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: torch.distributed.barrier() if megatron_cfg.peft is not None: - should_load_checkpoint = ( - megatron_cfg.checkpoint.load is not None - and checkpoint_exists(megatron_cfg.checkpoint.load) - ) + should_load_checkpoint = resume_checkpoint_exists if should_load_checkpoint: # The finetune toggle is explicitly set to True in order to avoid loading optimizer and RNG states # This is switched off here in order to load these states from the checkpoint megatron_cfg.checkpoint.finetune = False else: - should_load_checkpoint = ( - megatron_cfg.checkpoint.load is not None - and checkpoint_exists(megatron_cfg.checkpoint.load) - ) or ( - megatron_cfg.checkpoint.pretrained_checkpoint is not None - and checkpoint_exists(megatron_cfg.checkpoint.pretrained_checkpoint) + should_load_checkpoint = resume_checkpoint_exists or ( + pretrained_checkpoint_exists + and not preload_policy_from_pretrained_for_draft ) # Load checkpoint if applicable @@ -829,6 +917,8 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: print("Checkpoint loaded") torch.distributed.barrier() + draft_model = get_attached_draft_model(model) + # Set the param sync function for the model param_sync_func = None if megatron_cfg.ddp.overlap_param_gather and megatron_cfg.ddp.align_param_gather: @@ -846,6 +936,7 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: scheduler, checkpointing_context, param_sync_func, + draft_model=draft_model, ) diff --git 
a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 248c85f3ff..dc5c03cbdd 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -34,6 +34,7 @@ need_top_k_or_top_p_filtering, ) from nemo_rl.algorithms.loss import ( + DraftLossWrapper, SequencePackingFusionLossWrapper, SequencePackingLossWrapper, prepare_loss_input, @@ -49,7 +50,11 @@ from_parallel_logits_to_logprobs, from_parallel_logits_to_logprobs_packed_sequences, ) +from nemo_rl.models.megatron.config import MegatronModule from nemo_rl.models.megatron.data import ProcessedMicrobatch +from nemo_rl.models.megatron.draft.hidden_capture import ( + get_capture_context, +) from nemo_rl.models.policy import PolicyConfig # Union type for any post-processing function (defined after classes below) @@ -143,6 +148,8 @@ def forward_with_post_processing_fn( global_valid_toks: Optional[torch.Tensor] = None, sampling_params: Optional[TrainingSamplingParams] = None, straggler_timer: Optional[StragglerDetector] = None, + draft_model: Optional[MegatronModule] = None, + enable_hidden_capture: Optional[bool] = False, use_linear_ce_fusion_loss: bool = False, ) -> Tuple[torch.Tensor, Callable]: """Perform forward pass with pre-processed microbatch and return output tensor and post-processing function. 
@@ -160,6 +167,8 @@ def forward_with_post_processing_fn( global_valid_toks: Global valid token count for loss normalization sampling_params: Sampling parameters (top-k, top-p, temperature) straggler_timer: Straggler detector for profiling the forward pass + draft_model: Draft model for online draft model training + enable_hidden_capture: Whether to enable hidden state capture for draft model training Returns: tuple: (output_tensor, post_processing_fn_wrapped) @@ -178,17 +187,36 @@ def forward_with_post_processing_fn( packed_seq_params = processed_mb.packed_seq_params cu_seqlens_padded = processed_mb.cu_seqlens_padded - output_tensor = model_forward( - model=model, - data_dict=data_dict, - input_ids_cp_sharded=input_ids_cp_sharded, - position_ids=position_ids, - attention_mask=attention_mask, - packed_seq_params=packed_seq_params, - defer_fp32_logits=defer_fp32_logits, - straggler_timer=straggler_timer, - use_linear_ce_fusion_loss=use_linear_ce_fusion_loss, - ) + # Insert hook to capture hidden states and embeddings for draft model training if draft_model is provided + capture_context, capture = get_capture_context(model, enable_hidden_capture) + with capture_context: + output_tensor = model_forward( + model=model, + data_dict=data_dict, + input_ids_cp_sharded=input_ids_cp_sharded, + position_ids=position_ids, + attention_mask=attention_mask, + packed_seq_params=packed_seq_params, + defer_fp32_logits=defer_fp32_logits, + straggler_timer=straggler_timer, + use_linear_ce_fusion_loss=use_linear_ce_fusion_loss, + ) + + if capture is not None: + from megatron.core.transformer.multi_token_prediction import roll_tensor + + captured_states = capture.get_captured_states() + shifted_input_embeds = roll_tensor( + captured_states.inputs_embeds, + shifts=-1, + dims=0, + cp_group=get_context_parallel_group(), + )[0] + data_dict["student_logits"] = draft_model( + hidden_states=captured_states.hidden_states, + input_embeds=shifted_input_embeds, + attention_mask=attention_mask, + ) 
# Apply temperature scaling only for sampling-oriented post-processors. # Loss computation should use unscaled logits. @@ -241,6 +269,8 @@ def megatron_forward_backward( global_valid_toks: Optional[torch.Tensor] = None, sampling_params: Optional[TrainingSamplingParams] = None, straggler_timer: Optional[StragglerDetector] = None, + draft_model: Optional[MegatronModule] = None, + enable_hidden_capture: Optional[bool] = False, use_linear_ce_fusion_loss: bool = False, ) -> Any: """Execute forward and backward passes using Megatron's utilities. @@ -262,6 +292,8 @@ def megatron_forward_backward( global_valid_toks: Global valid token count for loss normalization sampling_params: Sampling parameters (top-k, top-p, temperature) straggler_timer: Straggler detector for profiling the forward pass + draft_model: Draft model for online draft model training + enable_hidden_capture: Whether to enable hidden state capture for draft model training Returns: Results from the forward/backward execution @@ -274,6 +306,8 @@ def megatron_forward_backward( global_valid_toks=global_valid_toks, sampling_params=sampling_params, straggler_timer=straggler_timer, + draft_model=draft_model, + enable_hidden_capture=enable_hidden_capture, use_linear_ce_fusion_loss=use_linear_ce_fusion_loss, ) forward_backward_func = get_forward_backward_func() @@ -362,6 +396,16 @@ def __call__( vocab_parallel_group=get_tensor_model_parallel_group(), context_parallel_group=get_context_parallel_group(), ) + if "student_logits" in data_dict: + loss_fn_wrapped = DraftLossWrapper( + loss_fn=loss_fn_wrapped, + prepare_fn=prepare_loss_input_wrapped, + data_dict=data_dict, + loss_weight=float(self.cfg["draft"]["loss_weight"]), + vocab_parallel_rank=get_tensor_model_parallel_rank(), + vocab_parallel_group=get_tensor_model_parallel_group(), + context_parallel_group=get_context_parallel_group(), + ) loss_fn_wrapped = partial( loss_fn_wrapped, diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py 
index 3636e5ac64..aa2e4d310a 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -226,6 +226,22 @@ class MegatronConfig(TypedDict): linear_ce_fusion_chunk_size: NotRequired[int] +class DraftConfigDisabled(TypedDict): + """Configuration shape for the disabled draft-model training path.""" + + enabled: Literal[False] + + +class DraftConfig(TypedDict): + """Configuration for Eagle draft-model training alongside the policy model.""" + + enabled: Literal[True] + model_name: NotRequired[str | None] + loss_weight: NotRequired[float] + num_layers: NotRequired[int | None] + aux_layer_indices: NotRequired[list[int] | None] + + class TokenizerConfig(TypedDict): name: str chat_template: NotRequired[str] @@ -287,6 +303,7 @@ class PolicyConfig(TypedDict): reward_model_cfg: NotRequired[RewardModelConfig] dtensor_cfg: DTensorConfig | DTensorConfigDisabled megatron_cfg: NotRequired[MegatronConfig | MegatronConfigDisabled] + draft: NotRequired[DraftConfig | DraftConfigDisabled] hf_config_overrides: NotRequired[dict[str, Any]] dynamic_batching: DynamicBatchingConfig | DynamicBatchingConfigDisabled sequence_packing: NotRequired[SequencePackingConfig | SequencePackingConfigDisabled] diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 8545fa2bc4..ab831b23ec 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -84,11 +84,24 @@ def __init__( megatron_enable = bool(config.get("megatron_cfg", {}).get("enabled", False)) dtensor_enable = bool(config.get("dtensor_cfg", {}).get("enabled", False)) + draft_enabled = bool(config.get("draft", {}).get("enabled", False)) if megatron_enable and dtensor_enable: raise ValueError( "Configure either Megatron (policy.megatron_cfg.enabled=true) or " "DTensor (policy.dtensor_cfg.enabled=true), not both." ) + if draft_enabled and not megatron_enable: + raise ValueError( + "policy.draft.enabled=true is only supported with the Megatron backend. 
" + "Set policy.megatron_cfg.enabled=true or disable policy.draft." + ) + if draft_enabled and bool( + config.get("sequence_packing", {}).get("enabled", False) + ): + raise ValueError( + "policy.draft.enabled=true does not support sequence packing yet. " + "Disable policy.sequence_packing.enabled or policy.draft." + ) if megatron_enable: worker_builder_cls_fqn = "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker" tp_size = config["megatron_cfg"]["tensor_model_parallel_size"] diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index fdb141fcf8..63754867d4 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -190,6 +190,7 @@ def __init__( self.scheduler = model_and_optimizer_state.scheduler self.checkpointing_context = model_and_optimizer_state.checkpointing_context param_sync_func = model_and_optimizer_state.param_sync_func + self.draft_model = model_and_optimizer_state.draft_model # Set the param sync function for the model if needed if param_sync_func is not None: @@ -340,6 +341,8 @@ def train( global_valid_toks=global_valid_toks, sampling_params=self.sampling_params, straggler_timer=self.mcore_state.straggler_timer, + draft_model=self.draft_model, + enable_hidden_capture=self.cfg["draft"]["enabled"], use_linear_ce_fusion_loss=self.cfg["megatron_cfg"].get( "use_linear_ce_fusion_loss", False ), @@ -552,6 +555,8 @@ def use_reference_model(self): # Swap reference model state_dict to self.model for k, v in self.model.state_dict().items(): if isinstance(v, torch.Tensor): + if "draft_model." in k: + continue v.copy_(self.reference_state_dict[k]) if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: @@ -918,9 +923,11 @@ def _calculate_refit_param_info(self) -> list[tuple[str, int]]: Returns: List of (parameter_name, size_in_bytes) tuples. 
""" - self.refit_conversion_tasks = self.megatron_bridge.get_conversion_tasks( - [self.model] - ) + self.refit_conversion_tasks = [ + task + for task in self.megatron_bridge.get_conversion_tasks([self.model]) + if task is not None + ] param_info = [] def calculate_size_in_bytes(param, tp_size, ep_size): @@ -978,6 +985,15 @@ def _iter_params_with_optional_kv_scales( for name, tensor in base_iter: yield name, tensor + if self.draft_model is not None: + from nemo_rl.models.megatron.draft import export_eagle_weights_to_hf + + draft_weights = export_eagle_weights_to_hf( + self.draft_model, + ) + for name, tensor in draft_weights: + yield f"draft.{name}", tensor + # Check whether FP8 KV cache is enabled. use_fp8_kv_cache = False if ( diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index 12840bcc66..a86cc080ce 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -54,6 +54,7 @@ run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh run_test uv run --no-sync bash ./tests/functional/grpo_megatron.sh +run_test uv run --no-sync bash ./tests/functional/grpo_megatron_eagle3_online.sh run_test uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh run_test fast uv run --no-sync bash ./tests/functional/grpo_megatron_lora.sh run_test fast uv run --no-sync bash ./tests/functional/grpo_megatron_lora_async.sh diff --git a/tests/functional/grpo_megatron_eagle3_online.sh b/tests/functional/grpo_megatron_eagle3_online.sh new file mode 100755 index 0000000000..15f8aaf1e8 --- /dev/null +++ b/tests/functional/grpo_megatron_eagle3_online.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath 
$SCRIPT_DIR/../..) + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +if [[ -z "${NRL_EAGLE3_DRAFT_MODEL:-}" ]]; then + echo "Skipping Eagle3 online functional test: set NRL_EAGLE3_DRAFT_MODEL to a compatible Eagle3 draft checkpoint." + exit 0 +fi + +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +POLICY_MODEL=${NRL_EAGLE3_POLICY_MODEL:-meta-llama/Llama-3.2-1B-Instruct} +CONFIG_PATH=$PROJECT_ROOT/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3.yaml + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo.py \ + --config $CONFIG_PATH \ + policy.model_name="$POLICY_MODEL" \ + policy.tokenizer.name="$POLICY_MODEL" \ + policy.draft.model_name="$NRL_EAGLE3_DRAFT_MODEL" \ + policy.generation.vllm_kwargs.speculative_config.model="$NRL_EAGLE3_DRAFT_MODEL" \ + grpo.max_num_steps=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +if grep -q "Speculative decoding is enabled without draft refit sync" "$RUN_LOG"; then + echo "Unexpected startup-weight warning for refit-backed Eagle3 path" + exit 1 +fi + +grep -q "Draft Loss:" "$RUN_LOG" + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'min(data["train/draft_loss"]) > 0' diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3.sh new file mode 100755 index 0000000000..140db9efb0 --- 
/dev/null +++ b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3.sh @@ -0,0 +1,60 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +GPUS_PER_NODE=4 +STEPS_PER_RUN=50 +MAX_STEPS=50 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +if [[ -z "${NRL_EAGLE3_DRAFT_MODEL:-}" ]]; then + echo "Need to set NRL_EAGLE3_DRAFT_MODEL to the path of a compatible Eagle3 draft checkpoint" + exit 1 +fi + +POLICY_MODEL=${NRL_EAGLE3_POLICY_MODEL:-meta-llama/Llama-3.2-1B-Instruct} + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + policy.model_name="$POLICY_MODEL" \ + policy.tokenizer.name="$POLICY_MODEL" \ + policy.draft.model_name="$NRL_EAGLE3_DRAFT_MODEL" \ + policy.generation.vllm_kwargs.speculative_config.model="$NRL_EAGLE3_DRAFT_MODEL" \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +if grep -q "Speculative decoding is enabled without draft refit sync" "$RUN_LOG"; then + echo "Unexpected startup-weight warning for refit-backed Eagle3 path" + exit 1 +fi + +grep -q "Draft Loss:" "$RUN_LOG" + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'min(data["train/draft_loss"]) > 0' + + # Clean up checkpoint directory after successful 
run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 61e474f2c3..21635a8915 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -22,6 +22,7 @@ tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh # Megatron tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.sh +tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3.sh tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh # Functional 32b run diff --git a/tests/unit/algorithms/test_draft_loss_wrapper.py b/tests/unit/algorithms/test_draft_loss_wrapper.py new file mode 100644 index 0000000000..362bb702ed --- /dev/null +++ b/tests/unit/algorithms/test_draft_loss_wrapper.py @@ -0,0 +1,91 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import MagicMock, patch + +import torch + +from nemo_rl.distributed.batched_data_dict import BatchedDataDict + + +@patch("nemo_rl.algorithms.loss.wrapper.DraftCrossEntropyLossFn") +def test_draft_loss_wrapper_combines_policy_and_draft_loss(mock_draft_loss_cls): + """DraftLossWrapper should add the weighted draft loss to the policy loss.""" + from nemo_rl.algorithms.loss.wrapper import DraftLossWrapper + + policy_loss = torch.tensor(3.0) + draft_loss = torch.tensor(2.0) + metrics = {"policy_metric": 1.0} + next_token_logits = torch.randn(1, 2, 3) + data = BatchedDataDict({}) + global_valid = torch.tensor(1) + + policy_loss_fn = MagicMock(return_value=(policy_loss, metrics.copy())) + prepare_fn = MagicMock(return_value=({"prepared": torch.tensor(1.0)}, data)) + draft_loss_fn = MagicMock(return_value=draft_loss) + mock_draft_loss_cls.return_value = draft_loss_fn + + wrapper = DraftLossWrapper( + loss_fn=policy_loss_fn, + prepare_fn=prepare_fn, + data_dict=data, + loss_weight=0.5, + ) + + combined_loss, combined_metrics = wrapper( + next_token_logits=next_token_logits, + data=data, + global_valid_seqs=global_valid, + global_valid_toks=global_valid, + ) + + assert combined_loss.item() == 4.0 + assert combined_metrics["draft_loss"] == draft_loss.item() + assert combined_metrics["policy_metric"] == metrics["policy_metric"] + + +@patch("nemo_rl.algorithms.loss.wrapper.DraftCrossEntropyLossFn") +def test_draft_loss_wrapper_reports_draft_loss_when_weight_is_zero( + mock_draft_loss_cls, +): + """A zero draft-loss weight should not suppress draft-loss reporting.""" + from nemo_rl.algorithms.loss.wrapper import DraftLossWrapper + + policy_loss = torch.tensor(5.0) + draft_loss = torch.tensor(1.5) + next_token_logits = torch.randn(1, 2, 3) + data = BatchedDataDict({}) + global_valid = torch.tensor(1) + + policy_loss_fn = MagicMock(return_value=(policy_loss, {})) + prepare_fn = MagicMock(return_value=({"prepared": torch.tensor(1.0)}, data)) + draft_loss_fn = 
MagicMock(return_value=draft_loss) + mock_draft_loss_cls.return_value = draft_loss_fn + + wrapper = DraftLossWrapper( + loss_fn=policy_loss_fn, + prepare_fn=prepare_fn, + data_dict=data, + loss_weight=0.0, + ) + + combined_loss, metrics = wrapper( + next_token_logits=next_token_logits, + data=data, + global_valid_seqs=global_valid, + global_valid_toks=global_valid, + ) + + assert combined_loss.item() == policy_loss.item() + assert metrics["draft_loss"] == draft_loss.item() diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index b65cb2d483..f11f629293 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -16,6 +16,7 @@ import os from copy import deepcopy from pathlib import Path +from unittest.mock import MagicMock import pytest import ray @@ -143,6 +144,51 @@ } +def test_configure_generation_config_uses_real_startup_weights_without_draft_refit(): + """Speculative training should not start the drafter from dummy weights without refit.""" + vllm_config = deepcopy(basic_vllm_test_config) + vllm_config["vllm_kwargs"] = { + "speculative_config": { + "method": "eagle3", + "model": "/tmp/draft-model", + "num_speculative_tokens": 3, + } + } + tokenizer = MagicMock(pad_token_id=0, eos_token_id=1) + + with pytest.warns(UserWarning, match="Speculative decoding is enabled"): + configured = configure_generation_config( + vllm_config, + tokenizer, + is_eval=False, + has_refit_draft_weights=False, + ) + + assert configured["vllm_cfg"]["load_format"] == "auto" + + +def test_configure_generation_config_keeps_dummy_startup_weights_with_draft_refit(): + """Speculative training can keep dummy startup weights when draft refit is available.""" + vllm_config = deepcopy(basic_vllm_test_config) + vllm_config["vllm_kwargs"] = { + "speculative_config": { + "method": "eagle3", + "model": "/tmp/draft-model", + "num_speculative_tokens": 3, + } + } + 
tokenizer = MagicMock(pad_token_id=0, eos_token_id=1) + + configured = configure_generation_config( + vllm_config, + tokenizer, + is_eval=False, + has_refit_draft_weights=True, + ) + + assert configured["vllm_cfg"]["load_format"] == "dummy" + + def get_basic_megatron_test_config( tp: int = 1, pp: int = 1, @@ -232,6 +278,9 @@ def get_basic_megatron_test_config( "overlap_param_gather": False, "data_parallel_sharding_strategy": "optim_grads_params", }, + "draft": { + "enabled": False, + }, }, "optimizer": None, # Remove default FSDP optimizer "scheduler": None, # Remove default scheduler diff --git a/tests/unit/models/megatron/test_megatron_setup.py b/tests/unit/models/megatron/test_megatron_setup.py index 948ace54b7..877ed34cab 100644 --- a/tests/unit/models/megatron/test_megatron_setup.py +++ b/tests/unit/models/megatron/test_megatron_setup.py @@ -24,6 +24,7 @@ - Model path validation """ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -1146,3 +1147,199 @@ def test_basic_finalize_setup( mock_auto_bridge.from_hf_pretrained.assert_called_once_with( "test-model", trust_remote_code=True ) + + +@pytest.mark.mcore +class TestDraftSetup: + """Tests for Eagle draft-model setup utilities.""" + + @staticmethod + def _build_model_provider(): + return SimpleNamespace( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_tensor_parallel_size=1, + sequence_parallel=False, + use_cpu_initialization=True, + fp16=False, + bf16=False, + params_dtype=torch.float32, + pipeline_dtype=torch.float32, + ffn_hidden_size=16, + num_attention_heads=2, + kv_channels=4, + num_query_groups=2, + init_method_std=0.02, + layernorm_epsilon=1e-5, + add_bias_linear=False, + attention_dropout=0.0, + hidden_size=8, + vocab_size=8, + seq_length=16, + position_embedding_type="rope", + rotary_percent=1.0, + rotary_base=10000, + rope_scaling=None, + rope_scaling_factor=None, + num_layers=4, + ) + + 
@patch("nemo_rl.models.megatron.setup.get_pg_collection") + @patch("nemo_rl.models.megatron.setup.build_unwrapped_draft_model") + def test_draft_pre_wrap_hook_attaches_only_owner_chunk( + self, mock_build_draft_model, mock_get_pg_collection + ): + """The nested draft model should attach only to the owner post-process chunk.""" + from nemo_rl.models.megatron.setup import _create_draft_pre_wrap_hook + + class DummyChunk(torch.nn.Module): + def __init__(self, *, post_process: bool = False): + super().__init__() + self.post_process = post_process + + chunks = [ + DummyChunk(post_process=False), + DummyChunk(post_process=True), + DummyChunk(post_process=False), + ] + draft_model = torch.nn.Linear(2, 2, bias=False) + mock_build_draft_model.return_value = draft_model + mock_get_pg_collection.return_value = MagicMock() + + hook = _create_draft_pre_wrap_hook( + policy_cfg={"draft": {"enabled": True, "model_name": None}}, + megatron_cfg=MagicMock(), + state=MagicMock(), + preload_policy_from_pretrained=False, + ) + + returned_model = hook(chunks) + + assert returned_model is chunks + assert getattr(chunks[0], "draft_model", None) is None + assert chunks[1].draft_model is draft_model + assert getattr(chunks[2], "draft_model", None) is None + mock_build_draft_model.assert_called_once() + assert ( + mock_build_draft_model.call_args.kwargs["policy_model_chunk"] is chunks[1] + ) + + @patch("nemo_rl.models.megatron.setup.copy_policy_lm_head_to_draft") + @patch("nemo_rl.models.megatron.draft.load_hf_weights_to_eagle") + @patch("nemo_rl.models.megatron.draft.EagleModel") + @patch("transformers.AutoConfig.from_pretrained") + def test_build_unwrapped_draft_model_falls_back_to_policy_lm_head( + self, + mock_auto_config, + mock_eagle_model, + mock_load_hf_weights, + mock_copy_lm_head, + ): + """Missing draft LM-head weights should fall back to the policy LM head.""" + from nemo_rl.models.megatron.setup import build_unwrapped_draft_model + + 
mock_auto_config.return_value.to_dict.return_value = { + "num_hidden_layers": 2, + "intermediate_size": 16, + "num_attention_heads": 2, + "head_dim": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-5, + "attention_dropout": 0.0, + "hidden_size": 8, + "vocab_size": 8, + "eagle_aux_hidden_state_layer_ids": [0, 2], + } + draft_model = MagicMock() + draft_model.modules.return_value = [] + mock_eagle_model.return_value = draft_model + mock_load_hf_weights.return_value = ( + ["eagle_module.eagle_output_layer.weight"], + [], + ) + policy_model_chunk = MagicMock() + + returned_model = build_unwrapped_draft_model( + model_provider=self._build_model_provider(), + draft_config={"enabled": True, "model_name": "dummy-draft"}, + pg_collection=SimpleNamespace(tp=None), + policy_model_chunk=policy_model_chunk, + ) + + assert returned_model is draft_model + mock_copy_lm_head.assert_called_once_with( + draft_model=draft_model, + policy_model_chunk=policy_model_chunk, + ) + + @patch("nemo_rl.models.megatron.setup.unwrap_model") + def test_copy_policy_lm_head_to_draft_raises_on_shape_mismatch( + self, mock_unwrap_model + ): + """Selected policy rows must match the draft LM-head shard shape.""" + from nemo_rl.models.megatron.setup import copy_policy_lm_head_to_draft + + policy_model = SimpleNamespace( + share_embeddings_and_output_weights=False, + output_layer=SimpleNamespace(weight=torch.randn(2, 4)), + ) + mock_unwrap_model.return_value = policy_model + draft_model = SimpleNamespace( + config=SimpleNamespace(draft_vocab_size=2), + eagle_module=SimpleNamespace( + eagle_output_layer=SimpleNamespace(weight=torch.zeros(3, 4)), + d2t=None, + ), + ) + + with pytest.raises(RuntimeError, match="local shard shapes differ"): + copy_policy_lm_head_to_draft( + draft_model=draft_model, + policy_model_chunk=MagicMock(), + ) + + @patch("nemo_rl.models.megatron.setup.get_pg_collection") + @patch("nemo_rl.models.megatron.setup.build_unwrapped_draft_model") + def 
test_attached_draft_state_is_serializable( + self, mock_build_draft_model, mock_get_pg_collection + ): + """Attached draft modules should be part of the owner chunk state_dict.""" + from nemo_rl.models.megatron.setup import _create_draft_pre_wrap_hook + + class DummyChunk(torch.nn.Module): + def __init__(self): + super().__init__() + self.post_process = True + self.base = torch.nn.Linear(2, 2, bias=False) + + mock_get_pg_collection.return_value = MagicMock() + + def attach_fresh_draft(): + chunk = DummyChunk() + hook = _create_draft_pre_wrap_hook( + policy_cfg={"draft": {"enabled": True, "model_name": None}}, + megatron_cfg=MagicMock(), + state=MagicMock(), + preload_policy_from_pretrained=False, + ) + hook([chunk]) + return chunk + + original_draft = torch.nn.Linear(2, 2, bias=False) + with torch.no_grad(): + original_draft.weight.fill_(3.14) + mock_build_draft_model.return_value = original_draft + owner_chunk = attach_fresh_draft() + state_dict = owner_chunk.state_dict() + + assert "draft_model.weight" in state_dict + + restored_draft = torch.nn.Linear(2, 2, bias=False) + mock_build_draft_model.return_value = restored_draft + restored_chunk = attach_fresh_draft() + restored_chunk.load_state_dict(state_dict) + + torch.testing.assert_close( + restored_chunk.draft_model.weight, + owner_chunk.draft_model.weight, + ) diff --git a/tests/unit/models/megatron/test_train.py b/tests/unit/models/megatron/test_train.py index eccf41defb..f4b9e36d96 100644 --- a/tests/unit/models/megatron/test_train.py +++ b/tests/unit/models/megatron/test_train.py @@ -22,6 +22,8 @@ - Loss/logprobs/topk post-processors """ +from contextlib import nullcontext +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -587,6 +589,119 @@ def test_forward_with_unknown_post_processor_raises(self, mock_model_forward): post_processing_fn=unknown_processor, ) + @patch("nemo_rl.models.megatron.train.get_capture_context") + 
@patch("nemo_rl.models.megatron.train.model_forward") + def test_forward_without_draft_model_does_not_inject_student_logits( + self, mock_model_forward, mock_get_capture_context + ): + """Without a draft model, the forward path should remain unchanged.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import ( + LogprobsPostProcessor, + forward_with_post_processing_fn, + ) + + mock_model_forward.return_value = torch.randn(2, 3, 5) + mock_get_capture_context.return_value = (nullcontext(), None) + + data_dict = MagicMock() + processed_mb = ProcessedMicrobatch( + data_dict=data_dict, + input_ids=torch.tensor([[1, 2, 3]]), + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=torch.ones(1, 3), + position_ids=torch.tensor([[0, 1, 2]]), + packed_seq_params=None, + cu_seqlens_padded=None, + ) + post_processor = LogprobsPostProcessor( + cfg={"sequence_packing": {"enabled": False}} + ) + + with patch.object(post_processor, "__call__", return_value=MagicMock()): + output, wrapped_fn = forward_with_post_processing_fn( + data_iterator=iter([processed_mb]), + model=MagicMock(), + post_processing_fn=post_processor, + draft_model=None, + ) + + assert "student_logits" not in data_dict + assert callable(wrapped_fn) + assert torch.equal(output, mock_model_forward.return_value) + + @patch("megatron.core.transformer.multi_token_prediction.roll_tensor") + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch("nemo_rl.models.megatron.train.get_capture_context") + @patch("nemo_rl.models.megatron.train.model_forward") + def test_forward_with_draft_model_rolls_input_embeds_before_draft_forward( + self, + mock_model_forward, + mock_get_capture_context, + mock_get_cp_group, + mock_roll_tensor, + ): + """Draft forward should consume the one-token-shifted input embeddings.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import ( + LogprobsPostProcessor, 
+ forward_with_post_processing_fn, + ) + + output_tensor = torch.randn(2, 3, 5) + student_logits = torch.randn(2, 3, 5) + hidden_states = torch.randn(3, 1, 4) + inputs_embeds = torch.randn(3, 1, 4) + shifted_embeds = torch.randn(3, 1, 4) + cp_group = MagicMock() + + mock_model_forward.return_value = output_tensor + mock_get_cp_group.return_value = cp_group + mock_roll_tensor.return_value = (shifted_embeds, None) + mock_capture = MagicMock() + mock_capture.get_captured_states.return_value = SimpleNamespace( + hidden_states=hidden_states, + inputs_embeds=inputs_embeds, + ) + mock_get_capture_context.return_value = (nullcontext(), mock_capture) + + data_dict = {} + attention_mask = torch.ones(1, 3) + processed_mb = ProcessedMicrobatch( + data_dict=data_dict, + input_ids=torch.tensor([[1, 2, 3]]), + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=attention_mask, + position_ids=torch.tensor([[0, 1, 2]]), + packed_seq_params=None, + cu_seqlens_padded=None, + ) + post_processor = LogprobsPostProcessor( + cfg={"sequence_packing": {"enabled": False}} + ) + draft_model = MagicMock(return_value=student_logits) + + with patch.object(post_processor, "__call__", return_value=MagicMock()): + forward_with_post_processing_fn( + data_iterator=iter([processed_mb]), + model=MagicMock(), + post_processing_fn=post_processor, + draft_model=draft_model, + ) + + mock_roll_tensor.assert_called_once_with( + inputs_embeds, + shifts=-1, + dims=0, + cp_group=cp_group, + ) + draft_model.assert_called_once_with( + hidden_states=hidden_states, + input_embeds=shifted_embeds, + attention_mask=attention_mask, + ) + assert data_dict["student_logits"] is student_logits + class TestMegatronForwardBackward: """Tests for megatron_forward_backward function.""" diff --git a/tests/unit/test_recipes_and_test_suites.py b/tests/unit/test_recipes_and_test_suites.py index 873162cbe9..37ba02e281 100644 --- a/tests/unit/test_recipes_and_test_suites.py +++ 
b/tests/unit/test_recipes_and_test_suites.py @@ -51,7 +51,7 @@ # Configuration keys that are allowed to be added to base configs during testing # These keys may exist in recipe configs but not in base configs, so we need to # manually add them to avoid merge conflicts during config validation -ALLOWED_ADDITIONAL_CONFIG_KEYS = ["policy.generation.vllm_kwargs"] +ALLOWED_ADDITIONAL_CONFIG_KEYS = ["policy.draft", "policy.generation.vllm_kwargs"] @pytest.fixture