# feat: Add Eagle3 online speculative decoding support #2078

**Open** · isomap wants to merge 14 commits into `NVIDIA-NeMo:main` from `isomap:feat/eagle3-online-specdec` · **+3,109 −46**
## Commits

- `5a0ab9e` Add Eagle3 online speculative decoding support
- `4b82e3e` Refactor Eagle3 online speculative decoding around draft models
- `c5f2a12` minor fix
- `075c447` TP/SP fix
- `ddb6141` Enhance draft model integration and loss computation
- `e715ddc` Merge branch 'main' into feat/eagle3-online-specdec
- `86a1ee8` minor fix
- `1fe14d7` Update nemo_rl/models/megatron/setup.py
- `79e622d` Refactor draft model configuration and integration
- `0aa28aa` add docstrings
- `5a2fc1c` add guide for Eagle3 Speculative Decoding in NeMo RL
- `3bcadf6` Enhance draft model functionality and loss computation
- `ca89067` Add Eagle3 configuration and online testing support
- `b5f5a9d` Merge branch 'main' into feat/eagle3-online-specdec
isomap File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## New file (+171 lines): Eagle3 speculative decoding guide
# Train with Eagle3 Speculative Decoding

Eagle3 speculative decoding speeds up rollout generation by running a smaller draft model in vLLM and having the policy model verify its proposals. In NeMo RL, you can either use a fixed Eagle3 draft model only for generation, or train that draft model online during RL so that it stays aligned with the policy.

This guide covers the NeMo RL-specific runtime and training path. For a high-level overview of speculative decoding, see [An Introduction to Speculative Decoding for Reducing Latency in AI Inference](https://developer.nvidia.com/blog/an-introduction-to-speculative-decoding-for-reducing-latency-in-ai-inference/). For GRPO fundamentals, see the [GRPO guide](grpo.md). For asynchronous rollout collection, see the [Async GRPO guide](async-grpo.md).
## Offline vs Online

- **Offline draft model**: vLLM uses a fixed Eagle3 checkpoint for speculative decoding, but the RL training loop does not update that draft model.
- **Online draft training**: NeMo RL attaches an Eagle3 draft model to the Megatron policy worker, trains it alongside the policy, and refits both policy and draft weights into vLLM.

Use the offline path when you already have a good drafter and only want faster rollouts. Use the online path when the policy changes during RL and you want the drafter to track those updates.
## Draft Checkpoint

For the best results, start from an Eagle checkpoint that was already pretrained as a draft model, then use NeMo RL's online draft loss to keep it aligned with the policy during RL. For training or adapting an Eagle checkpoint, see the [Model Optimizer speculative decoding example](https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/speculative_decoding/README.md).

If you are using a separately trained Eagle checkpoint, make sure its `eagle_config.json` contains:

```json
{
  "has_lm_head": true
}
```

NeMo RL now keeps a trainer-owned draft LM head. If the draft checkpoint contains `lm_head.weight`, NeMo RL loads it into the draft model. If that weight is absent, NeMo RL initializes the draft LM head from the current policy output layer instead.
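The head-initialization rule above can be sketched as follows. Note that `init_draft_lm_head` is a hypothetical helper written for illustration, not the actual NeMo RL function:

```python
import torch


def init_draft_lm_head(draft_lm_head, draft_state_dict, policy_lm_head):
    """Illustrative sketch of the trainer-owned draft LM head initialization.

    If the draft checkpoint ships its own `lm_head.weight`, load it;
    otherwise fall back to copying the current policy output layer.
    """
    with torch.no_grad():
        if "lm_head.weight" in draft_state_dict:
            # Draft checkpoint provides its own head: load it directly.
            draft_lm_head.weight.copy_(draft_state_dict["lm_head.weight"])
        else:
            # No head in the checkpoint: initialize from the policy head.
            draft_lm_head.weight.copy_(policy_lm_head.weight)
    return draft_lm_head
```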
## Enablement

### Generation Only

```yaml
policy:
  generation:
    backend: "vllm"
    vllm_kwargs:
      speculative_config:
        method: "eagle3"
        model: /path/to/eagle3-draft
        num_speculative_tokens: 3
```

This enables Eagle3 in vLLM, but the trainer does not own or update the draft model.
### Online Draft Training

```yaml
policy:
  megatron_cfg:
    enabled: true

  draft:
    enabled: true
    model_name: ${policy.generation.vllm_kwargs.speculative_config.model}
    loss_weight: 1.0

  generation:
    backend: "vllm"
    vllm_kwargs:
      speculative_config:
        method: "eagle3"
        model: /path/to/eagle3-draft
        num_speculative_tokens: 3
        draft_tensor_parallel_size: 1
```

> [!NOTE]
> Online draft training currently requires the Megatron backend. NeMo RL rejects `policy.draft.enabled=true` on the DTensor path.
## How It Works

### Rollout Path

During generation, vLLM runs the Eagle3 drafter configured in `policy.generation.vllm_kwargs.speculative_config`. When `policy.draft.enabled=true`, NeMo RL refits both:

- the policy weights into the main vLLM model
- the `draft.*` weights into the vLLM drafter

This keeps the rollout drafter aligned with the latest RL-updated policy instead of a stale checkpoint.
### Training Path

During the policy forward pass, NeMo RL captures:

- token input embeddings
- a small set of intermediate hidden states from auxiliary policy layers

These captured activations are the Eagle inputs. If `policy.draft.aux_layer_indices` is not set, NeMo RL chooses a default early/middle/late set of policy layers. The draft model then predicts logits in the same vocabulary space by reusing the policy LM head.
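One common way to capture such intermediate hidden states is a forward hook per auxiliary layer. The sketch below runs on a toy stack of linear layers; the layer indices, class name, and hook placement are illustrative assumptions, not NeMo RL's actual implementation:

```python
import torch
import torch.nn as nn


class AuxCapture:
    """Capture hidden states from selected layers via forward hooks (sketch)."""

    def __init__(self, layers, aux_layer_indices):
        self.captured = {}
        for idx in aux_layer_indices:
            # Each hook records the detached output of its layer.
            layers[idx].register_forward_hook(self._make_hook(idx))

    def _make_hook(self, idx):
        def hook(module, inputs, output):
            self.captured[idx] = output.detach()
        return hook


# Toy "policy": 12 stacked linear layers standing in for transformer blocks.
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
# Early/middle/late pick in the spirit of the default (indices are illustrative).
aux = AuxCapture(layers, aux_layer_indices=[1, 6, 10])

h = torch.randn(2, 8)
for layer in layers:
    h = layer(h)

assert set(aux.captured) == {1, 6, 10}
```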
### Draft Loss and Time-Step Alignment

The draft loss compares draft logits against detached policy logits, but only after aligning both sides to the same next-token event.

Suppose the policy input sequence is:

```text
[BOS, The, cat, sat]
```

The policy forward pass produces hidden states and logits at those positions:

```text
position:      0     1     2     3
input token:   [BOS] [The] [cat] [sat]
hidden state:  h0    h1    h2    h3
policy logits: p0    p1    p2    p3
predicts:      The   cat   sat   EOS
```

For Eagle training, NeMo RL does not compare raw `p0, p1, p2, p3` directly to the raw draft output. Instead, it shifts the draft inputs and teacher outputs so that draft position `t` predicts the teacher distribution for position `t+1`.

First, it rolls the captured input embeddings left by one token:

```text
original embeddings: e(BOS) e(The) e(cat) e(sat)
shifted embeddings:  e(The) e(cat) e(sat) -
```

Then it rolls the detached teacher logits left by one position:

```text
original teacher logits: p0  p1  p2  p3
rolled teacher logits:   p1  p2  p3  -
teacher meaning:         cat sat EOS -
```

So the aligned draft-training pairs become:

```text
(h0, e(The)) -> p1
(h1, e(cat)) -> p2
(h2, e(sat)) -> p3
```

In words:

- use the hidden state at position `t`
- use the embedding of the token at position `t+1`
- predict the teacher distribution for position `t+1`
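The left-shift alignment above can be sketched in a few lines of PyTorch. Shapes and the trailing-position mask are illustrative assumptions:

```python
import torch

B, T, H, V = 1, 4, 8, 16
embeddings = torch.randn(B, T, H)      # e(BOS), e(The), e(cat), e(sat)
hidden = torch.randn(B, T, H)          # h0..h3 (unshifted draft inputs)
teacher_logits = torch.randn(B, T, V)  # p0..p3 (detached policy logits)

# Roll both the draft input embeddings and the teacher logits left by one.
shifted_emb = torch.roll(embeddings, shifts=-1, dims=1)
rolled_teacher = torch.roll(teacher_logits, shifts=-1, dims=1)

# The last position wrapped around to the front, so mask it out of the loss.
loss_mask = torch.ones(B, T)
loss_mask[:, -1] = 0.0

# Aligned pair at t: (hidden[:, t], shifted_emb[:, t]) -> rolled_teacher[:, t]
assert torch.equal(shifted_emb[:, 0], embeddings[:, 1])
assert torch.equal(rolled_teacher[:, 2], teacher_logits[:, 3])
```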
After this alignment, the draft loss is:

$$
L_{draft} = \mathbb{E}_t \left[ -\sum_v \mathrm{softmax}(z_{policy,t})_v \, \log \mathrm{softmax}(z_{draft,t})_v \right]
$$

Here `z_{policy,t}` and `z_{draft,t}` refer to the aligned tensors after rolling, truncation, and masking, not the raw unshifted outputs of the forward pass.

This has the same student gradient as the forward KL from the policy distribution to the draft distribution, up to the constant teacher entropy term. The total training objective is:

$$
L_{total} = L_{policy} + \lambda \cdot L_{draft}
$$

where `lambda` is `policy.draft.loss_weight`.
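The identity behind that claim, soft cross-entropy equals forward KL plus the teacher entropy, can be checked numerically. The stand-in value for `L_policy` below is an arbitrary assumption for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
teacher_logits = torch.randn(2, 5, 16)
student_logits = torch.randn(2, 5, 16)

p = F.softmax(teacher_logits, dim=-1)
log_q = F.log_softmax(student_logits, dim=-1)

soft_ce = -(p * log_q).sum(-1)                          # per-token draft loss
kl = (p * (p.clamp_min(1e-12).log() - log_q)).sum(-1)   # KL(p || q)
entropy = -(p * p.clamp_min(1e-12).log()).sum(-1)       # H(p), constant in q

# soft CE = KL + H(p); gradients w.r.t. the student ignore the constant H(p).
assert torch.allclose(soft_ce, kl + entropy, atol=1e-5)

lam = 1.0                                # policy.draft.loss_weight
policy_loss = torch.tensor(2.0)          # arbitrary stand-in for L_policy
total_loss = policy_loss + lam * soft_ce.mean()
```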
## Important Knobs

- `policy.draft.enabled`: attach and train the Eagle draft model
- `policy.draft.model_name`: checkpoint used to initialize the draft model
- `policy.draft.loss_weight`: weight on the auxiliary draft loss
- `policy.draft.aux_layer_indices`: policy layers whose hidden states feed Eagle
- `policy.generation.vllm_kwargs.speculative_config.num_speculative_tokens`: number of speculative tokens proposed by vLLM
## Notes

- When online draft training is enabled, NeMo RL logs `draft_loss`.
- Resume checkpoints include the nested draft model state when `policy.draft.enabled=true`.
- If speculative decoding is enabled without trainer-owned draft weights, vLLM must load real draft weights at startup. When the trainer owns the draft model, the first refit pushes both policy and draft parameters.
## New file (+35 lines): examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3.yaml

```yaml
defaults: ./grpo-llama3.2-1b-instruct-1n4g-megatron.yaml

checkpointing:
  checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3

grpo:
  num_prompts_per_step: 4
  num_generations_per_prompt: 4

policy:
  train_global_batch_size: 16
  train_micro_batch_size: 1
  logprob_batch_size: 1
  max_total_sequence_length: 256
  sequence_packing:
    enabled: false
  draft:
    enabled: true
    model_name: "__set_nrl_eagle3_draft_model__"
    loss_weight: 1.0
  generation:
    max_new_tokens: 256
    vllm_cfg:
      max_model_len: 256
    vllm_kwargs:
      speculative_config:
        method: "eagle3"
        model: ${policy.draft.model_name}
        num_speculative_tokens: 3
        draft_tensor_parallel_size: 1

logger:
  log_dir: logs/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3
  wandb:
    name: grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3
```
## Loss module changes (@@ -12,17 +12,75 @@)

```python
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, NotRequired, Optional, TypedDict, TypeVar

import torch

from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType, LossType
from nemo_rl.algorithms.utils import calculate_kl, masked_mean
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.model_utils import DistributedCrossEntropy

Tensor = TypeVar("Tensor", bound=torch.Tensor)


class DraftCrossEntropyLossConfig(TypedDict):
    vocab_parallel_group: Optional[torch.distributed.ProcessGroup]


class DraftCrossEntropyLossDataDict(TypedDict):
    teacher_logits: Tensor
    student_logits: Tensor
    token_mask: Tensor
    sample_mask: Tensor
    student_vocab_indices: NotRequired[Tensor]


class DraftCrossEntropyLossFn(LossFunction):
    """Compute the auxiliary soft-target cross-entropy used for draft-model training."""

    loss_type = LossType.TOKEN_LEVEL
    input_type = LossInputType.DRAFT

    def __init__(
        self,
        vocab_parallel_rank: Optional[int] = None,
        vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        self.vocab_parallel_rank = vocab_parallel_rank
        self.vocab_parallel_group = vocab_parallel_group

    def __call__(
        self,
        teacher_logits: Tensor,
        student_logits: Tensor,
        token_mask: Tensor,
        data: BatchedDataDict[DraftCrossEntropyLossDataDict],
        global_valid_seqs: torch.Tensor,
        global_valid_toks: torch.Tensor,
    ) -> torch.Tensor:
        """Reduce the masked per-token draft loss to a scalar."""
        if self.vocab_parallel_group is not None:
            # Soft cross entropy matches the forward-KL student gradient.
            per_token_loss = DistributedCrossEntropy.apply(
                student_logits,
                teacher_logits,
                self.vocab_parallel_group,
                False,
            )
        else:
            teacher_probs = torch.nn.functional.softmax(teacher_logits, dim=-1)
            student_log_probs = torch.nn.functional.log_softmax(student_logits, dim=-1)
            per_token_loss = -(teacher_probs * student_log_probs).sum(dim=-1)

        mask = token_mask * data["sample_mask"].unsqueeze(-1)
        return masked_mean(
            per_token_loss,
            mask,
            global_normalization_factor=global_valid_toks,
        )


class ClippedPGLossConfig(TypedDict):
    reference_policy_kl_penalty: float
    reference_policy_kl_type: str
```

> **Review comment** (Contributor, on the `vocab_parallel_group` branch): do we have some test to cover this path (TP>1)? if not, can you add it?
> **Review comment** (Contributor): generally 1n4g is for gb200, can you also add one from `examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml` which is for h100?