171 changes: 171 additions & 0 deletions docs/guides/eagle3-speculative-decoding.md
@@ -0,0 +1,171 @@
# Train with Eagle3 Speculative Decoding

Eagle3 speculative decoding speeds up rollout generation by running a smaller draft model in vLLM and having the policy model verify its proposals. In NeMo RL, you can either use a fixed Eagle3 draft model only for generation, or train that draft model online during RL so it stays aligned with the policy.

This guide covers the NeMo RL-specific runtime and training path. For a high-level overview of speculative decoding, see [An Introduction to Speculative Decoding for Reducing Latency in AI Inference](https://developer.nvidia.com/blog/an-introduction-to-speculative-decoding-for-reducing-latency-in-ai-inference/). For GRPO fundamentals, see the [GRPO guide](grpo.md). For asynchronous rollout collection, see the [Async GRPO guide](async-grpo.md).

## Offline vs Online

- **Offline draft model**: vLLM uses a fixed Eagle3 checkpoint for speculative decoding, but the RL training loop does not update that draft model.
- **Online draft training**: NeMo RL attaches an Eagle3 draft model to the Megatron policy worker, trains it alongside the policy, and refits both policy and draft weights into vLLM.

Use the offline path when you already have a good drafter and only want faster rollouts. Use the online path when the policy is changing during RL and you want the drafter to track those updates.

## Draft Checkpoint

For the best results, start from an Eagle checkpoint that was already pretrained as a draft model, then use NeMo RL's online draft loss to keep it aligned with the policy during RL. For training or adapting an Eagle checkpoint, see the [Model Optimizer speculative decoding example](https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/speculative_decoding/README.md).

If you are using a separately trained Eagle checkpoint, make sure its `eagle_config.json` contains:

```json
{
  "has_lm_head": true
}
```

NeMo RL now keeps a trainer-owned draft LM head. If the draft checkpoint contains
`lm_head.weight`, NeMo RL loads it into the draft model. If that weight is absent,
NeMo RL initializes the draft LM head from the current policy output layer instead.
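A minimal sketch of that fallback logic (the function name and arguments are illustrative; the real loading path lives inside the Megatron policy worker):

```python
import torch


def init_draft_lm_head(
    draft_lm_head: torch.nn.Linear,
    draft_checkpoint_state: dict,
    policy_lm_head: torch.nn.Linear,
) -> None:
    """Load the draft LM head from the checkpoint if it shipped one,
    otherwise copy the current policy output layer."""
    with torch.no_grad():
        if "lm_head.weight" in draft_checkpoint_state:
            draft_lm_head.weight.copy_(draft_checkpoint_state["lm_head.weight"])
        else:
            draft_lm_head.weight.copy_(policy_lm_head.weight)
```

Either way, the trainer owns the head afterwards, so it is updated during online draft training and included in refits.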

## Enablement

### Generation Only

```yaml
policy:
  generation:
    backend: "vllm"
    vllm_kwargs:
      speculative_config:
        method: "eagle3"
        model: /path/to/eagle3-draft
        num_speculative_tokens: 3
```

This enables Eagle3 in vLLM, but the trainer does not own or update the draft model.

### Online Draft Training

```yaml
policy:
  megatron_cfg:
    enabled: true

  draft:
    enabled: true
    model_name: ${policy.generation.vllm_kwargs.speculative_config.model}
    loss_weight: 1.0

  generation:
    backend: "vllm"
    vllm_kwargs:
      speculative_config:
        method: "eagle3"
        model: /path/to/eagle3-draft
        num_speculative_tokens: 3
        draft_tensor_parallel_size: 1
```

> [!NOTE]
> Online draft training currently requires the Megatron backend. NeMo RL rejects `policy.draft.enabled=true` on the DTensor path.

## How It Works

### Rollout Path

During generation, vLLM runs the Eagle3 drafter from `policy.generation.vllm_kwargs.speculative_config`. When `policy.draft.enabled=true`, NeMo RL refits both:

- the policy weights into the main vLLM model
- the `draft.*` weights into the vLLM drafter

That keeps the rollout drafter aligned with the latest RL-updated policy instead of a stale checkpoint.

### Training Path

During the policy forward pass, NeMo RL captures:

- token input embeddings
- a small set of intermediate hidden states from auxiliary policy layers

Those captured activations are the Eagle inputs. If `policy.draft.aux_layer_indices` is not set, NeMo RL chooses a default early/middle/late set of policy layers. The draft model then predicts logits in the same vocabulary space by reusing the policy LM head.
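One way to picture the capture step is with forward hooks on a toy model. This is illustrative only — NeMo RL captures these activations inside the Megatron forward pass rather than via hooks, and `TinyPolicy` is a stand-in, not a real model class:

```python
import torch
import torch.nn as nn


class TinyPolicy(nn.Module):
    """Stand-in for the policy trunk: embeddings plus a stack of layers."""

    def __init__(self, vocab: int = 32, hidden: int = 16, n_layers: int = 6):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_layers))

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        h = self.embed(ids)
        for layer in self.layers:
            h = torch.tanh(layer(h))
        return h


def capture_aux_hidden_states(model: TinyPolicy, layer_indices, ids):
    """Run one forward pass and keep the outputs of the chosen layers."""
    captured = {}
    hooks = [
        model.layers[i].register_forward_hook(
            lambda mod, inp, out, i=i: captured.__setitem__(i, out.detach())
        )
        for i in layer_indices
    ]
    try:
        model(ids)
    finally:
        for h in hooks:
            h.remove()
    return captured


policy = TinyPolicy()
ids = torch.randint(0, 32, (1, 4))
# An early/middle/late selection, mirroring the default aux_layer_indices choice.
aux = capture_aux_hidden_states(policy, [1, 3, 5], ids)
```

The captured tensors play the role of the Eagle conditioning inputs; the real worker also captures the token input embeddings from the same pass.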

### Draft Loss and Time-Step Alignment

The draft loss compares draft logits against detached policy logits, but only after aligning both sides to the same next-token event.

Suppose the policy input sequence is:

```text
[BOS, The, cat, sat]
```

The policy forward pass produces hidden states and logits at those positions:

```text
position:       0      1      2      3
input token:   [BOS]  [The]  [cat]  [sat]
hidden state:   h0     h1     h2     h3
policy logits:  p0     p1     p2     p3
predicts:       The    cat    sat    EOS
```

For Eagle training, NeMo RL does not compare raw `p0, p1, p2, p3` directly to the raw draft output. Instead it shifts the draft inputs and teacher outputs so draft position `t` predicts the teacher distribution for position `t+1`.

First, it rolls the captured input embeddings left by one token:

```text
original embeddings: e(BOS) e(The) e(cat) e(sat)
shifted embeddings: e(The) e(cat) e(sat) -
```

Then it rolls the detached teacher logits left by one position:

```text
original teacher logits: p0 p1 p2 p3
rolled teacher logits: p1 p2 p3 -
teacher meaning: cat sat EOS -
```

So the aligned draft-training pairs become:

```text
(h0, e(The)) -> p1
(h1, e(cat)) -> p2
(h2, e(sat)) -> p3
```

In words:

- use the hidden state at position `t`
- use the embedding of the token at position `t+1`
- predict the teacher distribution for position `t+1`
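The shift-and-mask bookkeeping above can be sketched with `torch.roll` (shapes and variable names are illustrative, not NeMo RL's actual code):

```python
import torch

B, T, H, V = 1, 4, 8, 16  # batch, sequence [BOS, The, cat, sat], hidden, vocab
hidden_states = torch.randn(B, T, H)   # h0..h3 from the policy forward pass
embeddings = torch.randn(B, T, H)      # e(BOS)..e(sat), captured input embeddings
teacher_logits = torch.randn(B, T, V)  # p0..p3, detached policy logits

# Roll draft inputs and teacher targets one step left, so that position t
# pairs (h_t, e_{t+1}) with the teacher distribution p_{t+1}.
shifted_embeddings = torch.roll(embeddings, shifts=-1, dims=1)
rolled_teacher = torch.roll(teacher_logits.detach(), shifts=-1, dims=1)

# The final position wrapped around during the roll, so mask it out.
token_mask = torch.ones(B, T)
token_mask[:, -1] = 0.0

# Draft inputs concatenate the policy hidden state with the shifted embedding.
draft_inputs = torch.cat([hidden_states, shifted_embeddings], dim=-1)
```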

After this alignment, the draft loss is:

$$
L_{draft} = \mathbb{E}_t \left[ - \sum_v \mathrm{softmax}(z_{policy,t})_v \, \log \mathrm{softmax}(z_{draft,t})_v \right]
$$

Here `z_{policy,t}` and `z_{draft,t}` refer to the aligned tensors after rolling, truncation, and masking, not the raw unshifted outputs of the forward pass.

This has the same student gradient as forward KL from the policy distribution to the draft distribution, up to the teacher entropy constant. The total training objective is:

$$
L_{total} = L_{policy} + \lambda \cdot L_{draft}
$$

where $\lambda$ is `policy.draft.loss_weight`.
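The non-tensor-parallel branch of this loss reduces to a few lines of PyTorch (a sketch under illustrative shapes; `l_policy` is a placeholder for the RL policy loss, not a real computation):

```python
import torch
import torch.nn.functional as F


def draft_soft_cross_entropy(teacher_logits, student_logits, mask):
    """Masked mean of -sum_v softmax(teacher)_v * log_softmax(student)_v."""
    teacher_probs = F.softmax(teacher_logits.detach(), dim=-1)
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    per_token = -(teacher_probs * student_log_probs).sum(dim=-1)
    return (per_token * mask).sum() / mask.sum().clamp(min=1)


B, T, V = 2, 5, 11  # illustrative shapes
teacher = torch.randn(B, T, V)
student = torch.randn(B, T, V, requires_grad=True)
mask = torch.ones(B, T)
mask[:, -1] = 0.0  # rolled-off final positions contribute nothing

l_draft = draft_soft_cross_entropy(teacher, student, mask)
l_policy = torch.tensor(0.0)  # placeholder for the RL policy loss
loss_weight = 1.0             # policy.draft.loss_weight
l_total = l_policy + loss_weight * l_draft
```

Because the teacher is detached, gradients flow only into the draft parameters through `student_logits`, which is exactly the forward-KL student gradient up to the teacher entropy constant.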

## Important Knobs

- `policy.draft.enabled`: attach and train the Eagle draft model
- `policy.draft.model_name`: checkpoint used to initialize the draft model
- `policy.draft.loss_weight`: weight on the auxiliary draft loss
- `policy.draft.aux_layer_indices`: policy layers whose hidden states feed Eagle
- `policy.generation.vllm_kwargs.speculative_config.num_speculative_tokens`: number of speculative tokens proposed by vLLM

## Notes

- When online draft training is enabled, NeMo RL logs `draft_loss`.
- Resume checkpoints include the nested draft model state when `policy.draft.enabled=true`.
- If speculative decoding is enabled without trainer-owned draft weights, vLLM must load real draft weights at startup. When the trainer owns the draft model, the first refit pushes both policy and draft parameters.
7 changes: 7 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -212,6 +212,13 @@ policy:

env_vars: null

  draft:
    enabled: false
    model_name: null
    loss_weight: 0.1
    num_layers: null
    aux_layer_indices: null

  # See docs/design-docs/sequence-packing-and-dynamic-batching.md
  # for more details on dynamic batching and sequence packing.
  dynamic_batching:
@@ -0,0 +1,35 @@
defaults: ./grpo-llama3.2-1b-instruct-1n4g-megatron.yaml
> **Review comment (Contributor):** generally 1n4g is for gb200, can you also add one from examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml which is for h100?
checkpointing:
  checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3

grpo:
  num_prompts_per_step: 4
  num_generations_per_prompt: 4

policy:
  train_global_batch_size: 16
  train_micro_batch_size: 1
  logprob_batch_size: 1
  max_total_sequence_length: 256
  sequence_packing:
    enabled: false
  draft:
    enabled: true
    model_name: "__set_nrl_eagle3_draft_model__"
    loss_weight: 1.0
  generation:
    max_new_tokens: 256
    vllm_cfg:
      max_model_len: 256
    vllm_kwargs:
      speculative_config:
        method: "eagle3"
        model: ${policy.draft.model_name}
        num_speculative_tokens: 3
        draft_tensor_parallel_size: 1

logger:
  log_dir: logs/grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3
  wandb:
    name: grpo-llama3.2-1b-instruct-1n4g-megatron-eagle3
5 changes: 4 additions & 1 deletion examples/run_grpo.py
@@ -84,8 +84,11 @@ def main() -> None:
    assert config["policy"]["generation"] is not None, (
        "A generation config is required for GRPO"
    )
    has_refit_draft_weights = bool(config["policy"]["draft"]["enabled"])
    config["policy"]["generation"] = configure_generation_config(
        config["policy"]["generation"],
        tokenizer,
        has_refit_draft_weights=has_refit_draft_weights,
    )

    # setup data
4 changes: 4 additions & 0 deletions nemo_rl/algorithms/grpo.py
@@ -2118,6 +2118,8 @@ def grpo_train(
    print("\n📊 Training Results:")

    print(f" • Loss: {metrics['loss']:.4f}")
    if "draft_loss" in metrics:
        print(f" • Draft Loss: {metrics['draft_loss']:.4f}")
    print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}")
    if master_config["grpo"]["use_dynamic_sampling"]:
        print(f" • Avg Filtered Reward: {np.mean(rewards.numpy()):.4f}")
@@ -3130,6 +3132,8 @@ def async_grpo_train(

    print("\n📊 Training Results:")
    print(f" • Loss: {metrics['loss']:.4f}")
    if "draft_loss" in metrics:
        print(f" • Draft Loss: {metrics['draft_loss']:.4f}")
    print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}")
    print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
    print(f" • Buffer Size: {buffer_size_current}")
3 changes: 3 additions & 0 deletions nemo_rl/algorithms/loss/__init__.py
@@ -22,6 +22,7 @@
    DPOLossConfig,
    DPOLossDataDict,
    DPOLossFn,
    DraftCrossEntropyLossFn,
    NLLLossFn,
    PreferenceLossDataDict,
    PreferenceLossFn,
@@ -31,6 +32,7 @@
    prepare_packed_loss_input,
)
from nemo_rl.algorithms.loss.wrapper import (
    DraftLossWrapper,
    SequencePackingFusionLossWrapper,
    SequencePackingLossWrapper,
    wrap_loss_fn_with_input_preparation,
@@ -53,5 +55,6 @@
    "prepare_packed_loss_input",
    "SequencePackingFusionLossWrapper",
    "SequencePackingLossWrapper",
    "DraftLossWrapper",
    "wrap_loss_fn_with_input_preparation",
]
2 changes: 2 additions & 0 deletions nemo_rl/algorithms/loss/interfaces.py
@@ -29,6 +29,7 @@ class LossInputType(enum.Enum):
    LOGIT = "logit"
    LOGPROB = "logprob"
    DISTILLATION = "distillation"
    DRAFT = "draft"


class LossFunction(Protocol):
@@ -66,6 +67,7 @@ def __call__(
            - For LossInputType.LOGPROB: next_token_logprobs (torch.Tensor)
            - For LossInputType.LOGIT: logits (torch.Tensor)
            - For LossInputType.DISTILLATION: student_topk_logprobs, teacher_topk_logprobs, H_all (torch.Tensor)
            - For LossInputType.DRAFT: teacher_logits, student_logits, mask (torch.Tensor)

        Returns:
            tuple: (loss, metrics)
60 changes: 59 additions & 1 deletion nemo_rl/algorithms/loss/loss_functions.py
@@ -12,17 +12,75 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, NotRequired, Optional, TypedDict, TypeVar

import torch

from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType, LossType
from nemo_rl.algorithms.utils import calculate_kl, masked_mean
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.model_utils import DistributedCrossEntropy

Tensor = TypeVar("Tensor", bound=torch.Tensor)


class DraftCrossEntropyLossConfig(TypedDict):
    vocab_parallel_group: Optional[torch.distributed.ProcessGroup]


class DraftCrossEntropyLossDataDict(TypedDict):
    teacher_logits: Tensor
    student_logits: Tensor
    token_mask: Tensor
    sample_mask: Tensor
    student_vocab_indices: NotRequired[Tensor]


class DraftCrossEntropyLossFn(LossFunction):
    """Compute the auxiliary soft-target cross-entropy used for draft-model training."""

    loss_type = LossType.TOKEN_LEVEL
    input_type = LossInputType.DRAFT

    def __init__(
        self,
        vocab_parallel_rank: Optional[int] = None,
        vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        self.vocab_parallel_rank = vocab_parallel_rank
> **Review comment (Contributor):** `vocab_parallel_rank` seems not used? Suggested change: drop `self.vocab_parallel_rank = vocab_parallel_rank`.
        self.vocab_parallel_group = vocab_parallel_group

    def __call__(
        self,
        teacher_logits: Tensor,
        student_logits: Tensor,
        token_mask: Tensor,
        data: BatchedDataDict[DraftCrossEntropyLossDataDict],
        global_valid_seqs: torch.Tensor,
        global_valid_toks: torch.Tensor,
    ) -> torch.Tensor:
        """Reduce the masked per-token draft loss to a scalar."""
        if self.vocab_parallel_group is not None:
> **Review comment (Contributor):** do we have some test to cover this path (TP>1)? if not, can you add it?
            # Soft cross entropy matches the forward-KL student gradient.
            per_token_loss = DistributedCrossEntropy.apply(
                student_logits,
                teacher_logits,
                self.vocab_parallel_group,
                False,
            )
        else:
            teacher_probs = torch.nn.functional.softmax(teacher_logits, dim=-1)
            student_log_probs = torch.nn.functional.log_softmax(student_logits, dim=-1)
            per_token_loss = -(teacher_probs * student_log_probs).sum(dim=-1)

        mask = token_mask * data["sample_mask"].unsqueeze(-1)
        return masked_mean(
            per_token_loss,
            mask,
            global_normalization_factor=global_valid_toks,
        )


class ClippedPGLossConfig(TypedDict):
    reference_policy_kl_penalty: float
    reference_policy_kl_type: str