refactor: delegate ATOM KV cache subsystem to attention builders #659
Merged
Conversation
valarLip added a commit that referenced this pull request on Apr 30, 2026
yhl-amd added a commit to yhl-amd/ATOM that referenced this pull request on May 8, 2026
Adds Eagle3 spec decode for Kimi K2.5 (MLA target + standard MHA draft):

- `Eagle3LlamaModel`: 1-layer Llama draft (dual-norm input, wide QKV, independent embed/lm_head) matching the lightseekorg/kimi-k2.5-eagle3 checkpoint
- `Eagle3DraftBuilder`: implements the post-ROCm#659 builder protocol (`compute_block_bytes` / `allocate_kv_cache_tensors` / `build_kv_cache_tensor`) for the draft's independent non-MLA KV cache, attached to the runner from `EagleProposer.__init__` via `runner.eagle3_draft_builder`. ModelRunner delegates KV pool sizing, allocation, and per-module binding through this hook, with no Eagle3-specific code in the runner KV path
- Aux hidden state pipeline: the target forward returns `(hidden, aux_hidden_states)`, captured through CUDAGraph via `graph_aux_hidden` and fed to the draft's `combine_hidden_states` (fc) as input
- `SpeculativeConfig`: `--method eagle3` + `--draft-model` CLI flags; eagle3 vs. MTP branching at construction time; fail-fast if the draft is MLA
- Scheduler: `spec_stats` is only updated when speculation actually ran (matching vLLM's gating)
- `propose`: the draft-perspective predicate `draft_uses_mha = hasattr(runner, "eagle3_draft_builder")` drives the metadata-flow special cases (slot_mapping re-slice, `context_lens += 1`, tuple-unpack of the draft return value); the `is_eagle3` string comparison is gone from the hot path

Result on Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, 8x MI355X, gsm8k 5-shot: acceptance 67.85%, accuracy 93.78%.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
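For context, a sketch of how such a draft-side builder could plug into the post-#659 protocol. The hook name, the `runner.eagle3_draft_builder` attach point, and the `draft_uses_mha` predicate come from the commit message; the constructor fields and the byte math are illustrative assumptions.

```python
# Illustrative sketch only -- not the actual ATOM source.
class Eagle3DraftBuilder:
    """Owns the Eagle3 draft model's independent, non-MLA KV cache."""

    def __init__(self, runner, block_size: int, num_kv_heads: int,
                 head_dim: int, dtype_bytes: int = 2):
        self.block_size = block_size
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype_bytes = dtype_bytes
        # EagleProposer.__init__ attaches the builder here, so the runner's
        # generic KV path can discover it without eagle3-specific branching.
        runner.eagle3_draft_builder = self

    def compute_block_bytes(self) -> int:
        # Standard split-K/V MHA layout: one K and one V entry per token.
        return (2 * self.block_size * self.num_kv_heads
                * self.head_dim * self.dtype_bytes)


def draft_uses_mha(runner) -> bool:
    # Draft-perspective predicate from the commit message: replaces the old
    # is_eagle3 string comparison on the hot path.
    return hasattr(runner, "eagle3_draft_builder")
```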
yhl-amd added a commit to yhl-amd/ATOM that referenced this pull request on May 8, 2026
yhl-amd added a commit to yhl-amd/ATOM that referenced this pull request on May 8, 2026
yhl-amd added a commit to yhl-amd/ATOM that referenced this pull request on May 8, 2026
## Summary

Generalize the GDN per-request state decoupling (#602) into a complete model-agnostic KV cache abstraction owned by the `AttentionMetadataBuilder` hierarchy. ModelRunner becomes blind to attention type — it walks modules and dispatches via builder hooks; per-attention-type tensor layouts (MLA 576-dim packed, GDN-hybrid full-attn-only rows, MiMo-V2 per-module deferred, V3.2 indexer cache, GDN per-req mamba state) all live next to their respective builder.

ModelRunner net: -526 LOC. The `if/elif` chains over `use_mla` / `is_qwen_next` / `is_mimo_v2` / `is_deepseek_v32` in `_compute_block_bytes`, `allocate_kv_cache`, and the binding loop are all gone. Future stateful attentions (e.g. DeepseekV4 ring buffer + compressor state) plug in by subclassing `AttentionMetadataBuilder` without touching `Scheduler` / `BlockManager` / `ModelRunner`.

### New AttentionMetadataBuilder hooks (defaults are no-ops)
- `compute_per_req_cache_bytes()` / `slots_per_req()` — bytes/slot for the per-request state pool
- `allocate_per_req_cache(num_slots)` — dict of named per-request state tensors
- `compute_block_bytes()` — per-block bytes for the KV pool budget
- `allocate_kv_cache_tensors(num_kv_heads, num_draft_layers)` — dict of named primary KV cache tensors (`kv_cache`, `kv_scale`, `index_cache`, `aligned_index_dim`, `_kv_layer_cache_store`)
- `build_kv_cache_tensor(layer_id, module)` — vLLM-style `KVCacheTensor` for one module, or `None` if foreign type; owns module setattr (`k_cache`/`v_cache`/`k_scale`/`v_scale`/`kv_cache`)
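As a reading aid, a minimal sketch of what these defaults and the runner-side walk might look like. The hook names, the signatures listed above, and the named-tensor keys come from this PR; the no-op bodies, the `bind_kv_caches` helper, and its arguments are illustrative assumptions.

```python
from typing import Optional


class AttentionMetadataBuilder:
    """Sketch of the base-class defaults -- not the actual ATOM source.

    Every hook is a no-op, so a plain MHA builder overrides only what it needs.
    """

    # Per-request state pool (e.g. GDN recurrent state).
    def compute_per_req_cache_bytes(self) -> int:
        return 0  # 0 bytes => this attention type keeps no per-request state

    def slots_per_req(self) -> int:
        return 1  # one group spans this many contiguous raw slots

    def allocate_per_req_cache(self, num_slots: int) -> dict:
        return {}  # named per-request state tensors

    # Primary KV cache pool.
    def compute_block_bytes(self) -> int:
        return 0  # per-block bytes for the KV pool budget

    def allocate_kv_cache_tensors(self, num_kv_heads: int,
                                  num_draft_layers: int) -> dict:
        return {}  # e.g. {"kv_cache": ..., "kv_scale": ...}

    def build_kv_cache_tensor(self, layer_id: int, module) -> Optional[object]:
        return None  # None => this module belongs to some other builder


# Runner-side dispatch (also a sketch): ModelRunner just walks attention
# modules; the first builder that recognizes a module claims it and performs
# the module setattr side effects itself.
def bind_kv_caches(modules, builders) -> None:
    for layer_id, module in enumerate(modules):
        for builder in builders:
            if builder.build_kv_cache_tensor(layer_id, module) is not None:
                break
```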
### Builder overrides

- `AiterAttentionMetadataBuilder` — split-K/V MHA + MiMo-V2 per-module
- `AiterMLAMetadataBuilder` — 576-dim MLA + V3.2 indexer
- `GDNAttentionMetadataBuilder` — hybrid full-attn rows + GDN mamba slot pool; chains `super()` for MHA modules in hybrid models. Absorbs the formerly-runner-owned `gated_delta_net_state_shape` / `_dtypes` helpers and the side-effect init of `full_attention_interval` / `num_full_attn` / `num_gdn_attn_state`
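A sketch of the hybrid dispatch described for `GDNAttentionMetadataBuilder`, building on the base-class sketch above. The `super()`-chaining pattern is from this PR; `is_gdn_module` and the choice to return `None` for GDN layers are assumptions.

```python
def is_gdn_module(module) -> bool:
    # Hypothetical predicate -- the real check would inspect the module type.
    return module.__class__.__name__ == "GatedDeltaNet"


class GDNAttentionMetadataBuilder(AttentionMetadataBuilder):
    # In ATOM the parent would be the MHA builder, so that the hybrid model's
    # full-attention layers get real KV cache rows via super().
    def build_kv_cache_tensor(self, layer_id: int, module):
        if is_gdn_module(module):
            # In this sketch, GDN layers claim no row in the block-structured
            # KV cache: their recurrent state lives in the per-request pool.
            return None
        # Full-attention layers chain to the parent MHA path.
        return super().build_kv_cache_tensor(layer_id, module)
```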
### Naming: group vs. slot

Distinguishes group (per-request unit) from slot (raw tensor index). One group occupies `slots_per_req()` contiguous slots in the underlying tensor.

- `Sequence.mamba_state_slot` -> `.per_req_cache_group`
- `seq.mamba_enabled` -> `.has_per_req_cache`
- `batch.mamba_state_slots` -> `.per_req_cache_groups`
- `BlockManager.mamba_*` -> `.per_req_cache_*` (free pool, accounting)
- `config.mamba_equiv_per_req` -> `.per_req_cache_equiv_blocks`
- `config.num_mamba_groups` -> `.num_per_req_cache_groups`
- `ModelRunner.max_mamba_slots` -> `.max_per_req_cache_slots` (tensor dim)
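The group-to-slot mapping is plain index arithmetic. A minimal sketch, assuming groups are laid out contiguously as stated above (`group_slot_range` is a hypothetical helper, not ATOM code):

```python
def group_slot_range(group: int, slots_per_req: int) -> range:
    """Raw tensor slots occupied by one per-request cache group.

    Group g owns slots [g * slots_per_req, (g + 1) * slots_per_req).
    """
    start = group * slots_per_req
    return range(start, start + slots_per_req)


# e.g. with slots_per_req() == 2, group 3 owns raw slots 6 and 7:
assert list(group_slot_range(3, 2)) == [6, 7]
```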
### Removed

- `ModelRunner._compute_mamba_per_slot_bytes` (moved to `GDNAttentionMetadataBuilder.compute_per_req_cache_bytes`)
- `ModelRunner.gated_delta_net_state_shape` / `_dtypes` (moved to `GDNAttentionMetadataBuilder._state_shape` / `_state_dtypes`)
- `if/elif` dispatch in `_compute_block_bytes`, `allocate_kv_cache`, and the binding loop
### Sanity check

`ModelRunner.__init__` now asserts that any builder returning `compute_per_req_cache_bytes() > 0` has its `model_type` registered in `InputOutputProcessor._per_req_cache_model_types()`, catching the silent-corruption misconfiguration where a stateful attention is added but Sequence construction never gets the `has_per_req_cache=True` flag.
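A sketch of what that guard amounts to. The quoted names are from this PR; `_check_per_req_cache_registration` and its arguments are hypothetical:

```python
# Sketch of the init-time guard; illustrative, not the actual ATOM source.
def _check_per_req_cache_registration(builder, model_type: str,
                                      registered_types: set) -> None:
    if builder.compute_per_req_cache_bytes() > 0:
        assert model_type in registered_types, (
            f"{type(builder).__name__} declares per-request cache state, but "
            f"model_type {model_type!r} is not registered in "
            "InputOutputProcessor._per_req_cache_model_types(); Sequence "
            "construction would silently miss has_per_req_cache=True."
        )
```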
### Test plan

- `tests/test_per_req_cache_decoupling.py`: 24/24 pass
- core suite (`block_manager`, `sequence`, `scheduler`, `request`, `io_processor_fanout`, `prefix_cache_accuracy`): 118/118 pass
- Qwen3.5-397B-A17B-FP8 tp=4 `simple_inference`: 4-prompt completion quality unchanged
- Qwen3.5-397B-A17B-FP8 tp=4 GSM8K (5-shot, 64 concurrent): `flexible-extract = 0.8757 ± 0.0091` (baseline `0.8711` from #602, feat: decouple GDN recurrent state from KV cache block pool); `strict-match = 0.8605 ± 0.0095`
- docs updated: `scheduling_kv_cache_guide.md`, `architecture_guide.md`, `configuration_guide.md`, `model_support_guide.md`

### Repro