[Kimi] support Eagle3 speculative decoding for Kimi K2.5 #631

Merged
valarLip merged 6 commits into ROCm:main from yhl-amd:support_kimi_draft on May 10, 2026

Conversation

@yhl-amd (Contributor) commented Apr 22, 2026

Eagle3 Speculative Decoding for Kimi K2.5

Closes ROCm/ATOM#553

1. Goal

Add Eagle3 speculative decoding support for Kimi K2.5 (MLA target) in ATOM:

  • Functional: Eagle3 draft (standard MHA, independent KV cache, aux hidden state feedback) interoperates correctly with the MLA target model
  • Performance: target model can use CUDAGraph in Eagle3 mode, on the same fast path as MTP

The draft checkpoint used in this PR is lightseekorg/kimi-k2.5-eagle3 — a 1-layer standard Llama decoder (MHA). The MLA-draft variant is not yet released upstream; the abstraction introduced here is format-agnostic and can be extended to MLA draft when available.

End result: GSM8K 5-shot acceptance rate 67.85%, accuracy 93.78% on Kimi-K2.5-MXFP4.

2. Design

Eagle3 differs from the existing MTP path in ATOM along four key dimensions:

| Aspect | MTP | Eagle3 |
| --- | --- | --- |
| Draft attention | Same as target (MLA) | Standard MHA |
| Draft input | Target's last hidden | Concatenated aux hidden states from multiple target layers, projected through fc |
| KV cache | Shared with target | Allocated separately |
| Forward return | Single tensor | (hidden, aux_hidden_states) tuple |

Three pieces of design address these differences:

  • Eagle3DraftBuilder (post-refactor: delegate ATOM KV cache subsystem to attention builders #659 protocol): a sibling attention metadata builder that owns the draft's independent non-MLA KV cache through the same three hooks ModelRunner consults for the target builder (compute_block_bytes, allocate_kv_cache_tensors, build_kv_cache_tensor). Attached to the runner from EagleProposer.__init__ via runner.eagle3_draft_builder. ModelRunner has no Eagle3-specific code in its KV path; future heterogeneous-draft methods (Medusa, ReDrafter, etc.) only need to attach their own builder under the same protocol. Cache tensors land under namespaced keys (eagle3_kv_cache / eagle3_kv_scale) so they do not collide with the target builder's kv_cache / kv_scale. A minimal sketch of this builder shape follows the list.
  • Standard paged attention path for draft: draft and target share runner.block_size, so draft KV layout, slot_mapping values, and block_tables are all directly compatible with what target's prepare_decode wrote — no per-block-size translation needed. Pure-MHA layers (num_heads == num_kv_heads, satisfied by Eagle3 draft) are routed to paged_attention_triton (gluon) in dispatch_backend(), bypassing the aiter ASM pa_*.co family (which has a 16-bit block_id truncation issue when KV pool exceeds 65,535 blocks). In propose, the only metadata adaptation needed for the draft's first forward is re-slicing slot_mapping to len(input_ids) (after rejection it can be < bs * max_q_len, and the MHA cache-write kernel asserts slot_mapping <= q); context_lens is cloned (empirically +0.01 toks/fwd vs view, suspected async-ordering related) and bumped in-place per draft step.
  • Aux hidden state through CUDAGraph: a new graph_aux_hidden: dict[(bs, max_q_len), list[Tensor]] stores references to the aux tensors materialized inside the captured graph. After replay, slices are taken to the actual num_tokens for the drafter to consume. Type dispatch uses a use_aux_hidden_state_outputs flag set at init time, avoiding runtime isinstance checks that would pollute graph capture.
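
A minimal sketch of the builder shape the first bullet describes (not the actual ATOM code: the three hook names, the namespaced eagle3_kv_cache key, and the runner.eagle3_draft_builder attachment come from this PR; the signatures, the draft-config fields, and the layer_id marker are assumptions):

import torch

class Eagle3DraftBuilder:
    def __init__(self, runner, draft_cfg):
        self.runner = runner
        self.cfg = draft_cfg  # hypothetical holder for num_layers / num_kv_heads / head_dim / dtype

    def compute_block_bytes(self) -> int:
        # Bytes one KV block costs for the draft layers; ModelRunner adds this
        # to the target builder's figure when sizing the KV pool.
        c = self.cfg
        return (2 * c.num_layers * c.num_kv_heads * c.head_dim
                * self.runner.block_size * c.dtype.itemsize)

    def allocate_kv_cache_tensors(self, num_blocks: int) -> dict:
        # Namespaced key keeps the draft pool apart from the target's kv_cache.
        c = self.cfg
        shape = (c.num_layers, 2, num_blocks, self.runner.block_size,
                 c.num_kv_heads, c.head_dim)
        return {"eagle3_kv_cache": torch.empty(shape, dtype=c.dtype, device="cuda")}

    def build_kv_cache_tensor(self, module, kv_tensors):
        # Bind one draft attention module to its cache slice; return None for
        # non-attention modules so the caller falls through to the target builder.
        layer_id = getattr(module, "layer_id", None)  # assumed attribute
        if layer_id is None:
            return None
        return kv_tensors["eagle3_kv_cache"][layer_id]

EagleProposer.__init__ would then attach an instance as runner.eagle3_draft_builder, which is the only coupling point ModelRunner sees.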

3. Comparison with the existing MTP workflow

| Aspect | MTP (existing) | Eagle3 (this PR) |
| --- | --- | --- |
| Target forward return | Single tensor | (hidden, aux_hidden_states) |
| Target CUDAGraph | Captured directly | Same capture path + graph_aux_hidden saves aux references; sliced to num_tokens after replay |
| Draft KV cache | Shares physical memory with target | DraftKVCache allocated independently with its own block granularity |
| Draft attention metadata | Shared MLA metadata | Inherited from target's MLA prepare_decode; only slot_mapping re-sliced to len(input_ids), and context_lens cloned (empirically +0.01 toks/fwd) before bumping in-place per draft step |
| Draft cross-step hidden | Last hidden (same tensor) | Pre-norm hidden at midlayer exit (avoids second RMSNorm at decoder entry) |
| Draft model forward return | Single tensor | (post_norm, pre_norm); logits use post_norm, next step uses pre_norm |
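
The "Target CUDAGraph" row above condenses the capture/replay handling; a rough sketch of it (graph_aux_hidden and the init-time aux flag are from this PR, the surrounding plumbing and names are assumed):

import torch

def capture_target(model, graph_aux_hidden, key, use_aux, input_ids, positions):
    # Capture one (bs, max_q_len) bucket and keep references to the aux hidden
    # state tensors that live inside the captured graph.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = model(input_ids, positions)
        if use_aux:                      # flag fixed at init; no isinstance during capture
            hidden, aux = out
            graph_aux_hidden[key] = list(aux)
        else:
            hidden, aux = out, []
    return g, hidden

def replay_target(g, graph_aux_hidden, key, num_tokens):
    g.replay()
    # Slice the persistent graph-owned tensors down to the live token count
    # before handing them to the drafter.
    return [t[:num_tokens] for t in graph_aux_hidden[key]]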

Flow diagram

                  ┌──────────────────────────────────────┐
 Target forward   │  CUDAGraph capture / replay          │  ──► hidden
                  │  + graph_aux_hidden[(bs, q)] saves   │  ──► aux_hidden
                  └──────────────────────────────────────┘      (sliced to num_tokens after replay)
                                                            │
        ┌───────────────────────────────────────────────────┘
        ▼
  ┌─────────────────────────────────────────────────────────────────┐
  │ propose loop (i = 0 .. mtp_k-1)                                 │
  │                                                                 │
  │  ┌──────────────────────────────────────────────────┐           │
  │  │ Once: block_tables / context_lens / slot_mapping │  before   │
  │  │       ← forward_vars (scheduler-populated)       │   loop    │
  │  └──────────────────────────────────────────────────┘           │
  │                  ▼                                              │
  │  Draft forward  ─►  paged_attention_triton (gluon, MHA path)    │
  │                  ─►  (post_norm, pre_norm) two tensors          │
  │                  ▼                                              │
  │  logits     ◄── post_norm                                       │
  │  next hidden ◄── pre_norm  (midlayer exit, before final norm)   │
  │  context_lens += 1  (incremental, no rebuild)                   │
  └─────────────────────────────────────────────────────────────────┘
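
The propose loop above, condensed into a sketch (the slot_mapping re-slice, the cloned and in-place bumped context_lens, the (post_norm, pre_norm) unpack, and combine_hidden_states are from this PR; the function signature, compute_logits, and greedy sampling are illustrative assumptions):

def propose(draft_model, attn_metadata, input_ids, aux_hidden_states, mtp_k):
    # One-time metadata adaptation before the loop: the MHA cache-write kernel
    # asserts len(slot_mapping) <= number of query tokens, so re-slice after
    # rejection; context_lens is cloned once, then bumped in place per step.
    attn_metadata.slot_mapping = attn_metadata.slot_mapping[: len(input_ids)]
    attn_metadata.context_lens = attn_metadata.context_lens.clone()

    # fc projection over the concatenated aux hidden states from the target.
    hidden = draft_model.combine_hidden_states(aux_hidden_states)
    draft_tokens = []
    for _ in range(mtp_k):
        post_norm, pre_norm = draft_model(input_ids, hidden, attn_metadata)
        logits = draft_model.compute_logits(post_norm)  # assumed helper name
        input_ids = logits.argmax(dim=-1)               # greedy draft token
        hidden = pre_norm                               # midlayer-exit hidden feeds the next step
        attn_metadata.context_lens += 1                 # incremental, no rebuild
        draft_tokens.append(input_ids)
    return draft_tokens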

4. Changes

| File | Lines | Description |
| --- | --- | --- |
| atom/model_engine/model_runner.py | +137 | Generic delegation to runner.eagle3_draft_builder for KV pool sizing / allocation / per-module binding (no Eagle3-specific KV code in ModelRunner); CUDAGraph capture handles tuple return + graph_aux_hidden; aux layer registration |
| atom/models/eagle3_llama.py | +301 (new) | Eagle3LlamaModel, Eagle3LlamaDecoderLayer; forward returns (post_norm, pre_norm); use_sliding_window propagated from config to attention |
| atom/spec_decode/eagle.py | +240 | Eagle3DraftBuilder class (KV cache subsystem under the post-#659 builder protocol); Eagle3 load_model early-returns from a separate draft checkpoint (no embed/lm_head sharing); slot_mapping re-slice + cloned context_lens before the propose loop; unpack (post_norm, pre_norm); feed pre_norm to the next draft step; draft_uses_mha = hasattr(runner, "eagle3_draft_builder") drives metadata-flow predicates without is_eagle3 string compares |
| atom/model_engine/arg_utils.py | +30 | CLI args --method eagle3, --draft-model; popped in both spec-on and spec-off branches |
| atom/config.py | +28 | SpeculativeConfig parses Eagle3 fields; argument validation; default use_aux_hidden_state=True when eagle_config absent; Eagle3 architecture mapping in hf_config_override (architecture-level, not model_type-level) |
| atom/model_engine/scheduler.py | +13 -3 | spec_stats only updated when speculation actually ran (drafts proposed and validated), matching vLLM's gating; previously prefill-only steps inflated acceptance metrics |
| atom/models/deepseek_v2.py | +24 | Aux hidden state collection hooks; default layer IDs (2, N//2, N-3), aligned with vLLM |
| atom/model_ops/linear.py | +18 -5 | float16↔bfloat16 weight loading uses to() value conversion (only same-family dtypes such as fp8 variants keep view() bit reinterpretation) |
| atom/models/kimi_k25.py | +6 | Register Eagle3 aux layer interface |
| atom/model_ops/attention_mha.py | +13 -2 | Pure-MHA layers (num_heads == num_kv_heads) routed to triton/gluon, bypassing the aiter ASM pa_*.co family; sliding-window check tightened to > 0 to avoid -1 falsy mis-routing |

Total: +754 / -62 across 10 files.

Integration with PR #659 (KV cache subsystem refactor)

PR #659 moved per-attention-type KV cache allocation and binding into the attention metadata builder. Eagle3 integrates as follows:

  • Target KV cache: allocated and bound through attn_metadata_builder.allocate_kv_cache_tensors() / build_kv_cache_tensor() like any other model — no Eagle3-specific code in the target builder.
  • Draft KV cache: allocated through a sibling builder (Eagle3DraftBuilder) that implements the same three hooks the target builder exposes. EagleProposer.__init__ attaches it as runner.eagle3_draft_builder, and ModelRunner sums compute_block_bytes() and dispatches allocate_kv_cache_tensors() / build_kv_cache_tensor() across both builders generically. The draft cache tensors land under namespaced keys (eagle3_kv_cache / eagle3_kv_scale) so they don't collide with the target builder's kv_cache / kv_scale. A sketch of this dispatch follows the list.
  • Draft layer binding: the per-module loop tries eagle3_draft_builder first when model_name == "draft"; it returns None for non-MHA modules (LayerNorm, MLP, etc.) and falls through to attn_metadata_builder like any other module. The MLA target's attention modules never hit the draft builder because the loop only consults it under model_name == "draft".
  • mtp_start_layer_idx: the target MLA builder accesses this attribute through self.runner.mtp_start_layer_idx for MTP cache slot ordering. Eagle3 is not MTP — Eagle3LlamaModel has no model.model.mtp_start_layer_idx. When runner.eagle3_draft_builder is present, ModelRunner sets mtp_start_layer_idx = hf_config.num_hidden_layers (i.e. "no MTP layers above this index"), matching the behavior when no drafter is present.
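
Roughly what the generic dispatch looks like from the runner side (the hook names, the model_name == "draft" gate, and the None-falls-through behavior are from this PR; the free-function form and names are illustrative only):

def active_builders(runner):
    # Target builder always present; the draft builder only when Eagle3 attached it.
    builders = [runner.attn_metadata_builder]
    if hasattr(runner, "eagle3_draft_builder"):
        builders.append(runner.eagle3_draft_builder)
    return builders

def total_block_bytes(runner):
    # KV pool sizing: one logical block costs the sum across all builders.
    return sum(b.compute_block_bytes() for b in active_builders(runner))

def bind_module(runner, module, model_name, kv_tensors):
    # Draft modules consult the draft builder first; a None result falls
    # through to the target builder like any other module.
    if model_name == "draft" and hasattr(runner, "eagle3_draft_builder"):
        t = runner.eagle3_draft_builder.build_kv_cache_tensor(module, kv_tensors)
        if t is not None:
            return t
    return runner.attn_metadata_builder.build_kv_cache_tensor(module, kv_tensors)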

5. Usage

CLI arguments

| Argument | Required | Description |
| --- | --- | --- |
| --method eagle3 | yes | Enable Eagle3 speculative decoding (mutually exclusive with mtp) |
| --draft-model <path> | yes | Path to the Eagle3 draft model, e.g. /data/models/kimi-k2.5-eagle3 |
| --num-speculative-tokens <int> | yes | Number of autoregressive draft steps per iteration (recommended: 3) |

Aux hidden-state layer ids are derived from the draft model's eagle_config
when present, otherwise fall back to the target model's
get_eagle3_aux_hidden_state_layers() (default (2, N//2, N-3)).
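
For reference, the fallback rule as a one-liner (function name illustrative; the (2, N//2, N-3) rule is the one quoted above):

def default_aux_layer_ids(num_hidden_layers: int) -> tuple[int, int, int]:
    n = num_hidden_layers
    return (2, n // 2, n - 3)

# With 61 hidden layers this gives (2, 30, 58), matching the aux_layer_ids in
# Section 6 (61 is inferred from those IDs, not stated explicitly in this PR).
assert default_aux_layer_ids(61) == (2, 30, 58)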

OpenAI-compatible server

AITER_LOG_LEVEL=WARNING python -m atom.entrypoints.openai_server \
    --model /data/models/Kimi-K2.5-MXFP4 \
    --kv_cache_dtype fp8 \
    -tp 8 \
    --method eagle3 \
    --draft-model /data/models/kimi-k2.5-eagle3 \
    --num-speculative-tokens 3

Offline inference

python -m atom.examples.simple_inference \
    --model /data/models/Kimi-K2.5-MXFP4 \
    --kv_cache_dtype fp8 \
    -tp 8 \
    --method eagle3 \
    --draft-model /data/models/kimi-k2.5-eagle3 \
    --num-speculative-tokens 3

Verification

# Server startup: HTTP /health is not sufficient — confirm GPU memory is actually allocated
rocm-smi --showmemuse        # VRAM% > 0 means model is loaded

# Run lm_eval to verify acceptance rate and accuracy
lm_eval --model local-completions \
    --model_args base_url=http://localhost:8000/v1/completions,model=Kimi-K2.5-MXFP4 \
    --tasks gsm8k_cot_zeroshot --num_fewshot 5 --batch_size 64

Acceptance rate and per-position distribution are emitted in the server log statistics line.

Fallback

  • Disable speculative decoding: drop --method, --draft-model, --num-speculative-tokens
  • Keep speculation but disable CUDAGraph: add --enforce-eager

6. Results

Test environment: Kimi-K2.5-MXFP4 + lightseekorg/kimi-k2.5-eagle3, AMD MI350 × 8 (ROCm 7.2), TP=8, num_speculative_tokens=3, aux_layer_ids=(2,30,58), lm_eval GSM8K 5-shot.

| Metric | ATOM | vLLM |
| --- | --- | --- |
| Acceptance rate | 67.85% | 73.0% |
| Avg tokens / forward | 3.04 | 3.19 |
| GSM8K accuracy | 93.78% | 92.80% |

Acceptance distribution (mtp_k=3): 0/1/2/3 accepted = 15.0% / 16.4% / 18.7% / 49.9%.

Per-position acceptance (unconditional, drawn from the distribution above and from vLLM's metrics.py log line):

| Position | ATOM | vLLM |
| --- | --- | --- |
| 1 | 85.0% | 87.0% |
| 2 | 68.6% | 73.0% |
| 3 | 49.9% | 59.0% |
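
The ATOM column is simply the tail sum of the acceptance distribution above, and the same distribution reproduces the avg tokens / forward and overall acceptance figures; a quick consistency check, not an additional measurement:

dist = {0: 0.150, 1: 0.164, 2: 0.187, 3: 0.499}     # acceptance distribution above

per_position = {k: sum(p for n, p in dist.items() if n >= k) for k in (1, 2, 3)}
# {1: 0.850, 2: 0.686, 3: 0.499}  -> 85.0% / 68.6% / 49.9%

expected_accepted = sum(n * p for n, p in dist.items())   # 2.035
avg_tokens_per_forward = 1 + expected_accepted            # 3.035 ~ 3.04
overall_acceptance = expected_accepted / 3                # 0.678 ~ 67.85% (to rounding)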

GSM8K accuracy matches vLLM within the ±0.71% standard error, confirming numerical correctness of the draft sequence. The remaining ~5.2 pp acceptance-rate gap concentrates on positions 2 and 3, suggesting it stems from cross-step state propagation in the draft model rather than initial prediction quality. Likely contributors are bf16 numerical differences between ROCm and CUDA attention kernels and MXFP4 dequant implementation differences, both orthogonal to the scope of this PR.

7. Known limitations

  1. Draft only supports standard MHA (limited by the upstream checkpoint). If the draft config has a non-empty kv_lora_rank, startup raises NotImplementedError.
  2. The draft model itself still runs eager (only the target uses CUDAGraph).
  3. Eagle3 is not classified as MTP. Anything that derives an MTP-specific value from self.drafter.model.model.<...> (e.g. mtp_start_layer_idx) must guard on hasattr(runner, "eagle3_draft_builder") and fall back to hf_config.num_hidden_layers, since Eagle3LlamaModel has no model attribute (see the sketch below).
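
A minimal sketch of the guard from item 3 (attribute paths follow the Integration section; the surrounding structure is assumed):

if hasattr(runner, "eagle3_draft_builder") or not hasattr(runner, "drafter"):
    # Eagle3 draft (or no drafter at all): "no MTP layers above this index".
    mtp_start_layer_idx = hf_config.num_hidden_layers
else:
    # MTP path: the drafter shares the target's module tree.
    mtp_start_layer_idx = runner.drafter.model.model.mtp_start_layer_idx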

@yhl-amd force-pushed the support_kimi_draft branch from da4a5c8 to 46da277 on May 5, 2026 16:29
@yhl-amd force-pushed the support_kimi_draft branch 2 times, most recently from 4c921bb to dadefd7 on May 7, 2026 03:36
Comment thread atom/spec_decode/eagle.py Outdated
slot_mapping[:] = kv_indices[kv_indptr[1 : bs + 1] - 1]

if is_eagle3:
    attn_metadata.context_lens = attn_metadata.context_lens + 1
Collaborator: ...

layer_num=layer_num,
prefix=f"{prefix}.attn",
rotary_emb=self.rotary_emb,
per_layer_sliding_window=sliding_window,
Collaborator: ...

Contributor Author: fixed, draft block size = target block size

Comment thread atom/model_engine/model_runner.py Outdated
block_bytes = self.attn_metadata_builder.compute_block_bytes()
# Eagle3 draft layer uses a separate non-MLA KV cache; account for it
# alongside the target model's per-block bytes.
if getattr(self, "eagle3_mode", False) and hasattr(self, "drafter"):
Collaborator: ...

Comment thread atom/config.py
self.use_aux_hidden_state = eagle_cfg.get("use_aux_hidden_state", False)
if self.use_aux_hidden_state and not self.eagle3_aux_layer_ids:
    self.eagle3_aux_layer_ids = eagle_cfg.get(
        "eagle_aux_hidden_state_layer_ids", []
Collaborator: ...

Contributor Author: fixed, add comment

Comment thread atom/model_ops/paged_attention.py Outdated
dtype=dtype,
q_norm=q_norm,
k_norm=k_norm,
block_size=block_size,
Collaborator: ...

@yhl-amd force-pushed the support_kimi_draft branch 3 times, most recently from 0c6c01d to 9a77468 on May 8, 2026 08:15
Adds Eagle3 spec decode for Kimi K2.5 (MLA target + standard MHA draft):

- Eagle3LlamaModel: 1-layer Llama draft (dual-norm input, wide QKV,
  independent embed/lm_head) matching the lightseekorg/kimi-k2.5-eagle3
  checkpoint
- Eagle3DraftBuilder: implements the post-ROCm#659 builder protocol
  (compute_block_bytes / allocate_kv_cache_tensors / build_kv_cache_tensor)
  for the draft's independent non-MLA KV cache, attached to the runner from
  EagleProposer.__init__ via runner.eagle3_draft_builder. ModelRunner
  delegates KV pool sizing, allocation, and per-module binding through this
  hook with no eagle3-specific code in the runner KV path
- Aux hidden state pipeline: target forward returns
  (hidden, aux_hidden_states), captured through CUDAGraph via
  graph_aux_hidden and fed to the draft's combine_hidden_states (fc) as
  input
- SpeculativeConfig: --method eagle3 + --draft-model CLI; eagle3 vs MTP
  branching at construction time; fail-fast if draft is MLA
- Scheduler: spec_stats only updated when speculation actually ran
  (matches vLLM's gating)
- propose: draft-perspective predicate `draft_uses_mha = hasattr(runner,
  "eagle3_draft_builder")` drives both the metadata-flow special-cases
  (slot_mapping re-slice, context_lens += 1, tuple-unpack of the draft
  return value); is_eagle3 string comparison is gone from the hot path

Result on Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, 8x MI355X, gsm8k 5-shot:
acceptance 67.85%, accuracy 93.78%.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@yhl-amd force-pushed the support_kimi_draft branch from 9a77468 to f3f85f2 on May 8, 2026 08:26
yhl-amd and others added 5 commits May 8, 2026 18:18
Reuses the base Kimi-K2.5-MXFP4 model + lightseekorg/kimi-k2.5-eagle3
draft, runs at TP=8 (Eagle3 draft KV needs full 8-rank sharding) under
nightly schedule. Local case_verify_v9_gluon measured GSM8K 5-shot
flexible-extract = 0.9257 (vLLM = 0.9280); threshold set to 0.91 with
~1.5pp noise headroom.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@valarLip merged commit 679422d into ROCm:main on May 10, 2026
25 of 28 checks passed