[Kimi] support Eagle3 speculative decoding for Kimi K2.5#631
Merged
Conversation
valarLip reviewed on May 7, 2026
```python
slot_mapping[:] = kv_indices[kv_indptr[1 : bs + 1] - 1]

if is_eagle3:
    attn_metadata.context_lens = attn_metadata.context_lens + 1
```
```python
layer_num=layer_num,
prefix=f"{prefix}.attn",
rotary_emb=self.rotary_emb,
per_layer_sliding_window=sliding_window,
```
Contributor (Author): fixed, draft block size = target block size
```python
block_bytes = self.attn_metadata_builder.compute_block_bytes()
# Eagle3 draft layer uses a separate non-MLA KV cache; account for it
# alongside the target model's per-block bytes.
if getattr(self, "eagle3_mode", False) and hasattr(self, "drafter"):
```
```python
self.use_aux_hidden_state = eagle_cfg.get("use_aux_hidden_state", False)
if self.use_aux_hidden_state and not self.eagle3_aux_layer_ids:
    self.eagle3_aux_layer_ids = eagle_cfg.get(
        "eagle_aux_hidden_state_layer_ids", []
    )
```
```python
dtype=dtype,
q_norm=q_norm,
k_norm=k_norm,
block_size=block_size,
```
Adds Eagle3 spec decode for Kimi K2.5 (MLA target + standard MHA draft):

- Eagle3LlamaModel: 1-layer Llama draft (dual-norm input, wide QKV, independent embed/lm_head) matching the lightseekorg/kimi-k2.5-eagle3 checkpoint
- Eagle3DraftBuilder: implements the post-ROCm#659 builder protocol (compute_block_bytes / allocate_kv_cache_tensors / build_kv_cache_tensor) for the draft's independent non-MLA KV cache, attached to the runner from EagleProposer.__init__ via runner.eagle3_draft_builder. ModelRunner delegates KV pool sizing, allocation, and per-module binding through this hook, with no eagle3-specific code in the runner KV path
- Aux hidden state pipeline: target forward returns (hidden, aux_hidden_states), captured through CUDAGraph via graph_aux_hidden and fed to the draft's combine_hidden_states (fc) as input
- SpeculativeConfig: --method eagle3 + --draft-model CLI; eagle3 vs MTP branching at construction time; fail-fast if the draft is MLA
- Scheduler: spec_stats only updated when speculation actually ran (matches vLLM's gating)
- propose: the draft-perspective predicate `draft_uses_mha = hasattr(runner, "eagle3_draft_builder")` drives the metadata-flow special-cases (slot_mapping re-slice, context_lens += 1) and the tuple-unpack of the draft return value; the is_eagle3 string comparison is gone from the hot path

Result on Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, 8x MI355X, gsm8k 5-shot: acceptance 67.85%, accuracy 93.78%.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Reuses the base Kimi-K2.5-MXFP4 model + lightseekorg/kimi-k2.5-eagle3 draft and runs at TP=8 (the Eagle3 draft KV needs full 8-rank sharding) on the nightly schedule. Local case_verify_v9_gluon measured GSM8K 5-shot flexible-extract = 0.9257 (vLLM = 0.9280); the threshold is set to 0.91 with ~1.5 pp noise headroom.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
valarLip approved these changes on May 10, 2026
Eagle3 Speculative Decoding for Kimi K2.5
Closes ROCm/ATOM#553
1. Goal
Add Eagle3 speculative decoding support for Kimi K2.5 (MLA target) in ATOM:
- The draft checkpoint used in this PR is lightseekorg/kimi-k2.5-eagle3, a 1-layer standard Llama decoder (MHA). The MLA-draft variant is not yet released upstream; the abstraction introduced here is format-agnostic and can be extended to an MLA draft when one is available.
- End result: GSM8K 5-shot acceptance rate 67.85%, accuracy 93.78% on Kimi-K2.5-MXFP4.
2. Design
Eagle3 differs from the existing MTP path in ATOM along four key dimensions:

- the target forward must also return aux hidden states, which the draft combines via `fc(hidden, aux_hidden_states)`
- the draft is a standard MHA Llama layer, not MLA, so it cannot share the target's MLA KV cache
- the draft loads its own embed/lm_head from a separate checkpoint instead of sharing the target's
- the draft forward returns a `(post_norm, pre_norm)` tuple rather than a single hidden state

Three pieces of design address these differences:
- `Eagle3DraftBuilder` (under the post-refactor protocol of "delegate ATOM KV cache subsystem to attention builders" #659): a sibling attention metadata builder that owns the draft's independent non-MLA KV cache through the same three hooks ModelRunner consults for the target builder (`compute_block_bytes`, `allocate_kv_cache_tensors`, `build_kv_cache_tensor`). It is attached to the runner from `EagleProposer.__init__` via `runner.eagle3_draft_builder`. ModelRunner has no Eagle3-specific code in its KV path; future heterogeneous-draft methods (Medusa, ReDrafter, etc.) only need to attach their own builder under the same protocol. Cache tensors land under namespaced keys (`eagle3_kv_cache`/`eagle3_kv_scale`) so they do not collide with the target builder's `kv_cache`/`kv_scale`.
- Draft KV cache at the target's block size: the draft cache is allocated at `runner.block_size`, so draft KV layout, `slot_mapping` values, and `block_tables` are all directly compatible with what the target's `prepare_decode` wrote; no per-block-size translation is needed. Pure-MHA layers (`num_heads == num_kv_heads`, satisfied by the Eagle3 draft) are routed to `paged_attention_triton` (gluon) in `dispatch_backend()`, bypassing the aiter ASM `pa_*.co` family, which has a 16-bit `block_id` truncation issue when the KV pool exceeds 65,535 blocks. In `propose`, the only metadata adaptation needed for the draft's first forward is re-slicing `slot_mapping` to `len(input_ids)` (after rejection it can be < `bs * max_q_len`, and the MHA cache-write kernel asserts `slot_mapping <= q`); `context_lens` is cloned (empirically +0.01 toks/fwd vs a view, suspected to be async-ordering related) and bumped in-place per draft step (sketched below).
- `graph_aux_hidden: dict[(bs, max_q_len), list[Tensor]]` stores references to the aux tensors materialized inside the captured graph. After replay, slices are taken to the actual `num_tokens` for the drafter to consume. Type dispatch uses a `use_aux_hidden_state_outputs` flag set at init time, avoiding runtime `isinstance` checks that would pollute graph capture.
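The metadata adaptation from the second bullet, as a standalone sketch; the `attn_metadata` field names follow the PR text, while the helper function itself is illustrative rather than ATOM's actual code:

```python
import torch

def adapt_metadata_for_draft(attn_metadata, input_ids: torch.Tensor):
    """Illustrative sketch of the draft's first-forward adaptation."""
    # After rejection sampling, len(input_ids) can be < bs * max_q_len,
    # and the MHA cache-write kernel asserts len(slot_mapping) <= q,
    # so the slot mapping is re-sliced to the surviving token count.
    attn_metadata.slot_mapping = attn_metadata.slot_mapping[: input_ids.shape[0]]

    # Clone rather than view: the PR reports a small empirical win
    # (+0.01 toks/fwd), suspected to be async-ordering related.
    attn_metadata.context_lens = attn_metadata.context_lens.clone()
    return attn_metadata

# Inside the propose loop, context_lens is then bumped in-place once per
# draft step:
#     attn_metadata.context_lens += 1
```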
3. Comparison with the existing MTP workflow

| Aspect | MTP | Eagle3 |
| --- | --- | --- |
| Target forward output | final hidden state only | `(hidden, aux_hidden_states)` tuple |
| CUDAGraph capture | single output tensor | `graph_aux_hidden` saves aux references; sliced to `num_tokens` after replay |
| Draft KV cache | shares the target's MLA pool | `DraftKVCache` allocated independently (block size matches the target's) |
| Attention metadata | target `prepare_decode` output reused as-is | reuses `prepare_decode`; only `slot_mapping` re-sliced to `len(input_ids)`, and `context_lens` cloned (empirically +0.01 toks/fwd) before bumping in-place per draft step |
| Draft forward output | single hidden state | `(post_norm, pre_norm)`; logits use `post_norm`, next step uses `pre_norm` |

Flow diagram
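To make the last row of the table concrete, here is a minimal sketch of the draft-side dataflow; `draft_loop`, `draft_model`, `lm_head`, and `combine_fc` are illustrative stand-ins (not ATOM's actual `propose()` code), and greedy token selection is used only to keep the sketch self-contained:

```python
import torch

def draft_loop(draft_model, lm_head, combine_fc,
               input_ids: torch.Tensor,
               target_hidden: torch.Tensor,
               aux_hidden_states: list[torch.Tensor],
               num_steps: int) -> list[torch.Tensor]:
    """Illustrative dataflow only; not ATOM's propose() implementation."""
    # Eagle3 first combines the target's hidden state with the captured
    # aux hidden states via the draft's fc: fc(hidden, aux_hidden_states).
    hidden = combine_fc(target_hidden, aux_hidden_states)
    draft_tokens = []
    for _ in range(num_steps):
        post_norm, pre_norm = draft_model(input_ids, hidden)
        logits = lm_head(post_norm)        # logits use post_norm
        input_ids = logits.argmax(dim=-1)  # greedy pick, for the sketch
        draft_tokens.append(input_ids)
        hidden = pre_norm                  # next step uses pre_norm
    return draft_tokens
```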
4. Changes
- `atom/model_engine/model_runner.py`: `runner.eagle3_draft_builder` hook for KV pool sizing / allocation / per-module binding (no Eagle3-specific KV code in ModelRunner); CUDAGraph capture handles the tuple return + `graph_aux_hidden`; aux layer registration
- `atom/models/eagle3_llama.py`: `Eagle3LlamaModel`, `Eagle3LlamaDecoderLayer`; `forward` returns `(post_norm, pre_norm)`; `use_sliding_window` propagated from config to attention
- `atom/spec_decode/eagle.py`: `Eagle3DraftBuilder` class (KV cache subsystem under the post-#659 builder protocol); Eagle3 `load_model` early-returns from a separate draft checkpoint (no embed/lm_head sharing); `slot_mapping` re-slice + cloned `context_lens` before the propose loop; unpacks `(post_norm, pre_norm)` and feeds `pre_norm` to the next draft step; `draft_uses_mha = hasattr(runner, "eagle3_draft_builder")` drives metadata-flow predicates without `is_eagle3` string compares
- `atom/model_engine/arg_utils.py`: `--method eagle3`, `--draft-model`; popped in both spec-on and spec-off branches
- `atom/config.py`: `SpeculativeConfig` parses Eagle3 fields; argument validation; default `use_aux_hidden_state=True` when `eagle_config` is absent; Eagle3 architecture mapping in `hf_config_override` (architecture-level, not model_type-level)
- `atom/model_engine/scheduler.py`: `spec_stats` only updated when speculation actually ran (drafts proposed and validated), matching vLLM's gating; previously prefill-only steps inflated acceptance metrics (see the sketch after this list)
- `atom/models/deepseek_v2.py`: default aux hidden-state layers `(2, N//2, N-3)`, aligned with vLLM
- `atom/model_ops/linear.py`: `to()` value conversion (only same-family dtypes such as fp8 variants keep `view()` bit reinterpretation)
- `atom/models/kimi_k25.py`: target-side aux hidden state capture; forward returns `(hidden, aux_hidden_states)`
- `atom/model_ops/attention_mha.py`: pure-MHA decode (`num_heads == num_kv_heads`) routed to triton/gluon, bypassing the aiter ASM `pa_*.co` family; sliding-window check tightened to `> 0` to avoid `-1` falsy mis-routing

Total: +754 / -62 across 10 files.
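The scheduler change is easiest to see as a guard; the stats holder and counter names below are hypothetical, only the gating condition mirrors the PR:

```python
from dataclasses import dataclass

@dataclass
class SpecStats:  # hypothetical holder; ATOM's actual type may differ
    num_drafts: int = 0
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0

def update_spec_stats(stats: SpecStats, num_draft_tokens: int,
                      num_accepted: int) -> None:
    # Only count steps where drafts were actually proposed and validated;
    # prefill-only steps would otherwise inflate the acceptance metrics.
    if num_draft_tokens == 0:
        return
    stats.num_drafts += 1
    stats.num_draft_tokens += num_draft_tokens
    stats.num_accepted_tokens += num_accepted
```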
Integration with PR #659 (KV cache subsystem refactor)
PR #659 moved per-attention-type KV cache allocation and binding into the attention metadata builder. Eagle3 integrates as follows:
- Target: allocates through `attn_metadata_builder.allocate_kv_cache_tensors()` / `build_kv_cache_tensor()` like any other model; no Eagle3-specific code in the target builder.
- Draft: a sibling builder (`Eagle3DraftBuilder`) implements the same three hooks the target builder exposes. `EagleProposer.__init__` attaches it as `runner.eagle3_draft_builder`, and ModelRunner sums `compute_block_bytes()` and dispatches `allocate_kv_cache_tensors()` / `build_kv_cache_tensor()` across both builders generically (sketched below). The draft cache tensors land under namespaced keys (`eagle3_kv_cache`/`eagle3_kv_scale`) so they don't collide with the target builder's `kv_cache`/`kv_scale`.
- Per-module binding consults `eagle3_draft_builder` first when `model_name == "draft"`; it returns `None` for non-MHA modules (LayerNorm, MLP, etc.) and falls through to `attn_metadata_builder` like any other module. The MLA target's attention modules never hit the draft builder because the loop only consults it under `model_name == "draft"`.
- `mtp_start_layer_idx`: the target MLA builder accesses this attribute through `self.runner.mtp_start_layer_idx` for MTP cache slot ordering. Eagle3 is not MTP, and `Eagle3LlamaModel` has no `model.model.mtp_start_layer_idx`. When `runner.eagle3_draft_builder` is present, ModelRunner sets `mtp_start_layer_idx = hf_config.num_hidden_layers` (i.e. "no MTP layers above this index"), matching the behavior when no drafter is present.
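A minimal sketch of that three-hook protocol and the generic dispatch; the hook names come from the PR, while the `Protocol` class, `size_and_allocate`, and the free-memory sizing loop are illustrative assumptions:

```python
from typing import Optional, Protocol
import torch

class KVCacheBuilder(Protocol):
    # The three hooks of the post-#659 builder protocol.
    def compute_block_bytes(self) -> int: ...
    def allocate_kv_cache_tensors(self, num_blocks: int) -> dict[str, torch.Tensor]: ...
    def build_kv_cache_tensor(self, module) -> Optional[torch.Tensor]: ...

def size_and_allocate(runner, free_bytes: int) -> dict[str, torch.Tensor]:
    builders: list[KVCacheBuilder] = [runner.attn_metadata_builder]
    draft = getattr(runner, "eagle3_draft_builder", None)
    if draft is not None:
        builders.append(draft)  # sibling builder for the draft's MHA cache

    # Pool sizing sums per-block bytes across builders, so target and
    # draft caches share one block budget.
    num_blocks = free_bytes // sum(b.compute_block_bytes() for b in builders)

    # Each builder allocates under its own keys (e.g. kv_cache vs
    # eagle3_kv_cache), so the dicts never collide.
    tensors: dict[str, torch.Tensor] = {}
    for b in builders:
        tensors.update(b.allocate_kv_cache_tensors(num_blocks))
    return tensors
```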
5. Usage

CLI arguments
- `--method eagle3`: selects Eagle3 speculation (alongside the existing `mtp`)
- `--draft-model <path>`: draft checkpoint, e.g. `/data/models/kimi-k2.5-eagle3`
- `--num-speculative-tokens <int>`: number of draft tokens proposed per step

Aux hidden-state layer ids are derived from the draft model's `eagle_config` when present, otherwise fall back to the target model's `get_eagle3_aux_hidden_state_layers()` (default `(2, N//2, N-3)`, sanity-checked in the sketch below).
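The default formula can be checked against the Results section; `N = 61` below is an assumption about Kimi K2.5's decoder depth, chosen because it reproduces the reported `aux_layer_ids`:

```python
def default_aux_layers(num_hidden_layers: int) -> tuple[int, int, int]:
    # Default aux hidden-state tap layers: (2, N//2, N-3), aligned with vLLM.
    n = num_hidden_layers
    return (2, n // 2, n - 3)

# Assuming N = 61 decoder layers, this reproduces the
# aux_layer_ids=(2, 30, 58) reported in section 6.
assert default_aux_layers(61) == (2, 30, 58)
```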
OpenAI-compatible server

```bash
AITER_LOG_LEVEL=WARNING python -m atom.entrypoints.openai_server \
  --model /data/models/Kimi-K2.5-MXFP4 \
  --kv_cache_dtype fp8 \
  -tp 8 \
  --method eagle3 \
  --draft-model /data/models/kimi-k2.5-eagle3 \
  --num-speculative-tokens 3
```

Offline inference
```bash
python -m atom.examples.simple_inference \
  --model /data/models/Kimi-K2.5-MXFP4 \
  --kv_cache_dtype fp8 \
  -tp 8 \
  --method eagle3 \
  --draft-model /data/models/kimi-k2.5-eagle3 \
  --num-speculative-tokens 3
```

Verification
Acceptance rate and per-position distribution are emitted in the server log statistics line.
Fallback
- Omitting `--method`, `--draft-model`, and `--num-speculative-tokens` runs without speculation.
- `--enforce-eager` disables CUDAGraph capture.
Test environment: Kimi-K2.5-MXFP4 +
lightseekorg/kimi-k2.5-eagle3, AMD MI350 × 8 (ROCm 7.2), TP=8,num_speculative_tokens=3,aux_layer_ids=(2,30,58), lm_eval GSM8K 5-shot.Acceptance distribution (
mtp_k=3): 0/1/2/3 accepted = 15.0% / 16.4% / 18.7% / 49.9%.Per-position acceptance (unconditional, drawn from the distribution above and from vLLM's
metrics.pylog line):GSM8K accuracy matches vLLM within the ±0.71% standard error, confirming numerical correctness of the draft sequence. The remaining ~5.2 pp acceptance-rate gap concentrates on positions 2 and 3, suggesting it stems from cross-step state propagation in the draft model rather than initial prediction quality. Likely contributors are bf16 numerical differences between ROCm and CUDA attention kernels and MXFP4 dequant implementation differences, both orthogonal to the scope of this PR.
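The per-position numbers follow directly from the acceptance distribution, since position k is accepted exactly when the step accepted at least k draft tokens; their mean also recovers the headline acceptance rate:

```python
# Acceptance distribution from above: P(exactly k of 3 drafts accepted).
dist = {0: 0.150, 1: 0.164, 2: 0.187, 3: 0.499}

# Unconditional per-position acceptance: P(accepted >= k) for k = 1..3.
per_position = [sum(p for k, p in dist.items() if k >= pos) for pos in (1, 2, 3)]
print(per_position)            # ~[0.850, 0.686, 0.499]

# Mean per-position acceptance recovers the headline number (~67.85%).
print(sum(per_position) / 3)   # ~0.678
```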
7. Known limitations
- MLA draft checkpoints are unsupported: if the draft config carries `kv_lora_rank`, startup raises `NotImplementedError`.
- Code paths that reach into `self.drafter.model.model.<...>` (e.g. `mtp_start_layer_idx`) must guard on `hasattr(runner, "eagle3_draft_builder")` and fall back to `hf_config.num_hidden_layers`; `Eagle3LlamaModel` has no `model` attribute (see the sketch below).
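A sketch of that guard; the attribute paths follow the PR text, while the wrapper function and the `runner.hf_config` access are assumptions:

```python
def resolve_mtp_start_layer_idx(runner) -> int:
    # Guard from the limitation above: Eagle3LlamaModel has no nested
    # `model` attribute and no MTP layers, so fall back to
    # num_hidden_layers ("no MTP layers above this index").
    if hasattr(runner, "eagle3_draft_builder"):
        return runner.hf_config.num_hidden_layers
    # MTP drafters expose the index under the nested model path.
    return runner.drafter.model.model.mtp_start_layer_idx
```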