refactor: delegate ATOM KV cache subsystem to attention builders#659

Merged
valarLip merged 2 commits into main from refactor/per-req-cache-abstraction on Apr 29, 2026

Conversation

@valarLip
Collaborator

Summary

Generalize the GDN per-request state decoupling (#602) into a complete model-agnostic KV cache abstraction owned by the AttentionMetadataBuilder hierarchy. ModelRunner becomes blind to attention type — it walks modules and dispatches via builder hooks; per-attention-type tensor layouts (MLA 576-dim packed, GDN-hybrid full-attn-only rows, MiMo-V2 per-module deferred, V3.2 indexer cache, GDN per-req mamba state) all live next to their respective builder.

ModelRunner net: -526 LOC. The if/elif chains over use_mla / is_qwen_next / is_mimo_v2 / is_deepseek_v32 in _compute_block_bytes, allocate_kv_cache, and the binding loop are all gone. Future stateful attentions (e.g. DeepseekV4 ring buffer + compressor state) plug in by subclassing AttentionMetadataBuilder without touching Scheduler / BlockManager / ModelRunner.
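As a rough illustration of that dispatch (a minimal sketch, not the actual ATOM code: the bind_kv_caches name, the attn_type marker, and the builder-list argument are assumptions; only the build_kv_cache_tensor hook comes from this PR):

def bind_kv_caches(model, builders):
    # Offer every attention module to each builder; the first builder that
    # recognizes the module returns a KVCacheTensor (and sets k_cache/v_cache/
    # k_scale/v_scale/kv_cache on the module itself), the rest return None.
    kv_cache_tensors = {}
    attn_modules = [m for m in model.modules() if hasattr(m, "attn_type")]  # marker assumed
    for layer_id, module in enumerate(attn_modules):
        for builder in builders:
            tensor = builder.build_kv_cache_tensor(layer_id, module)
            if tensor is not None:
                kv_cache_tensors[layer_id] = tensor
                break
    return kv_cache_tensors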

New AttentionMetadataBuilder hooks (defaults are no-ops)

  • compute_per_req_cache_bytes() / slots_per_req(): bytes per slot for the per-request state pool
  • allocate_per_req_cache(num_slots): dict of named per-request state tensors
  • compute_block_bytes(): per-block bytes for the KV pool budget
  • allocate_kv_cache_tensors(num_kv_heads, num_draft_layers): dict of named primary KV tensors (kv_cache, kv_scale, index_cache, aligned_index_dim, _kv_layer_cache_store)
  • build_kv_cache_tensor(layer_id, module): vLLM-style KVCacheTensor for one module, or None if the module is a foreign type; owns the module setattr (k_cache/v_cache/k_scale/v_scale/kv_cache)
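A minimal sketch of that hook surface, with signatures inferred from the list above (return types and default values beyond "defaults are no-ops" are assumptions, not the real base class):

class AttentionMetadataBuilder:
    # per-request state pool (used by stateful attentions such as GDN)
    def compute_per_req_cache_bytes(self) -> int:
        return 0                                   # no per-request state by default
    def slots_per_req(self) -> int:
        return 1
    def allocate_per_req_cache(self, num_slots):
        return {}                                  # {name: tensor}

    # paged KV pool
    def compute_block_bytes(self) -> int:
        return 0                                   # contribution to the KV pool budget
    def allocate_kv_cache_tensors(self, num_kv_heads, num_draft_layers):
        return {}                                  # {name: tensor}
    def build_kv_cache_tensor(self, layer_id, module):
        return None                                # None => module is a foreign type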

Builder overrides

  • AiterAttentionMetadataBuilder — split-K/V MHA + MiMo-V2 per-module
  • AiterMLAMetadataBuilder — 576-dim MLA + V3.2 indexer
  • GDNAttentionMetadataBuilder — hybrid full-attn rows + GDN mamba slot pool; chains super() for MHA modules in hybrid models. Absorbs the formerly-runner-owned gated_delta_net_state_shape/_dtypes helpers and the side-effect init of full_attention_interval / num_full_attn / num_gdn_attn_state.
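For the hybrid case, a hedged sketch of the GDN override pattern (the parent class, the torch allocation, the is_gdn marker, and the conv/SSM tensor names are assumptions; the hooks and the moved _state_shape/_state_dtypes helpers come from this PR):

import math
import torch

class GDNAttentionMetadataBuilder(AiterAttentionMetadataBuilder):  # parent assumed; base covers plain MHA
    def compute_per_req_cache_bytes(self) -> int:
        # bytes of recurrent state per slot, from the shape/dtype helpers
        # formerly owned by ModelRunner
        return sum(math.prod(shape) * dtype.itemsize
                   for shape, dtype in zip(self._state_shape(), self._state_dtypes()))

    def allocate_per_req_cache(self, num_slots):
        names = ("conv_state", "ssm_state")        # illustrative tensor names
        return {name: torch.empty((num_slots, *shape), dtype=dtype, device="cuda")
                for name, shape, dtype in zip(names, self._state_shape(),
                                              self._state_dtypes())}

    def build_kv_cache_tensor(self, layer_id, module):
        if getattr(module, "is_gdn", False):       # marker attribute assumed
            return None                            # GDN layers hold no paged KV rows
        return super().build_kv_cache_tensor(layer_id, module)  # full-attn rows via the MHA parent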

Naming: group vs. slot

The renaming distinguishes a group (the per-request unit) from a slot (a raw tensor index): one group occupies slots_per_req() contiguous slots in the underlying tensor, as sketched after the table below.

  Sequence.mamba_state_slot     -> .per_req_cache_group
  seq.mamba_enabled             -> .has_per_req_cache
  batch.mamba_state_slots       -> .per_req_cache_groups
  BlockManager.mamba_*          -> .per_req_cache_*  (free pool, accounting)
  config.mamba_equiv_per_req    -> .per_req_cache_equiv_blocks
  config.num_mamba_groups       -> .num_per_req_cache_groups
  ModelRunner.max_mamba_slots   -> .max_per_req_cache_slots  (tensor dim)
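The group-to-slot arithmetic this implies (illustrative only):

def group_to_slots(group: int, slots_per_req: int) -> range:
    # one group (a request's per-request cache unit) owns slots_per_req
    # contiguous raw slots in the underlying state tensor
    first = group * slots_per_req
    return range(first, first + slots_per_req)

# e.g. with slots_per_req() == 2, group 3 occupies raw slots 6 and 7
assert list(group_to_slots(3, 2)) == [6, 7]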

Removed

  • ModelRunner._compute_mamba_per_slot_bytes (moved to GDNAttentionMetadataBuilder.compute_per_req_cache_bytes)
  • ModelRunner.gated_delta_net_state_shape / _dtypes (moved to GDNAttentionMetadataBuilder._state_shape / _state_dtypes)
  • The 4-way if/elif dispatch in _compute_block_bytes, allocate_kv_cache, and the binding loop

Sanity check

ModelRunner.__init__ now asserts that any builder returning compute_per_req_cache_bytes() > 0 has its model_type registered in InputOutputProcessor._per_req_cache_model_types(). This catches the silent-corruption misconfiguration where a stateful attention is added but Sequence construction never sets has_per_req_cache=True.
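Roughly (a sketch with assumed argument names; only the hook, the processor helper, and the flag appear in this PR):

def check_per_req_cache_registration(builders, model_type, registered_types):
    for builder in builders:
        if builder.compute_per_req_cache_bytes() > 0:
            assert model_type in registered_types, (
                f"{type(builder).__name__} allocates per-request state but "
                f"{model_type!r} is not in "
                "InputOutputProcessor._per_req_cache_model_types(); sequences "
                "would run with has_per_req_cache=False and silently corrupt state")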

Test plan

  • tests/test_per_req_cache_decoupling.py: 24/24 pass
  • Core suite (block_manager, sequence, scheduler, request, io_processor_fanout, prefix_cache_accuracy): 118/118 pass
  • Qwen3.5-397B-A17B-FP8 tp=4 simple_inference: 4-prompt completion quality unchanged
  • Qwen3.5-397B-A17B-FP8 tp=4 GSM8K (5-shot, NUM_CONCURRENT=64): flexible-extract = 0.8757 +/- 0.0091 (baseline 0.8711 from #602), strict-match = 0.8605 +/- 0.0095
  • Docs synced (scheduling_kv_cache_guide.md, architecture_guide.md, configuration_guide.md, model_support_guide.md)

Repro

# Unit tests
python -m pytest tests/test_per_req_cache_decoupling.py tests/test_block_manager.py \
  tests/test_sequence.py tests/test_scheduler.py tests/test_request.py \
  tests/test_io_processor_fanout.py tests/test_prefix_cache_accuracy.py -q

# Qwen3.5 smoke (4 prompts, deterministic)
bash /app/logs_claude/run_simple_inference_streamed.sh /data/Qwen3.5-397B-A17B-FP8 4 --temperature 0.0

# Qwen3.5 GSM8K
bash /app/logs_claude/start_atom_server.sh /data/Qwen3.5-397B-A17B-FP8 4 8001
NUM_CONCURRENT=64 bash /app/logs_claude/run_gsm8k_eval.sh /data/Qwen3.5-397B-A17B-FP8 8001 5

Copilot AI review requested due to automatic review settings April 28, 2026 16:11
valarLip merged commit 99f0990 into main Apr 29, 2026
52 of 54 checks passed
valarLip deleted the refactor/per-req-cache-abstraction branch April 29, 2026 08:37
yhl-amd added a commit to yhl-amd/ATOM that referenced this pull request May 8, 2026
Adds Eagle3 spec decode for Kimi K2.5 (MLA target + standard MHA draft):

- Eagle3LlamaModel: 1-layer Llama draft (dual-norm input, wide QKV,
  independent embed/lm_head) matching the lightseekorg/kimi-k2.5-eagle3
  checkpoint
- Eagle3DraftBuilder: implements the post-ROCm#659 builder protocol
  (compute_block_bytes / allocate_kv_cache_tensors / build_kv_cache_tensor)
  for the draft's independent non-MLA KV cache, attached to the runner from
  EagleProposer.__init__ via runner.eagle3_draft_builder. ModelRunner
  delegates KV pool sizing, allocation, and per-module binding through this
  hook with no eagle3-specific code in the runner KV path
- Aux hidden state pipeline: target forward returns
  (hidden, aux_hidden_states), captured through CUDAGraph via
  graph_aux_hidden and fed to the draft's combine_hidden_states (fc) as
  input
- SpeculativeConfig: --method eagle3 + --draft-model CLI; eagle3 vs MTP
  branching at construction time; fail-fast if draft is MLA
- Scheduler: spec_stats only updated when speculation actually ran
  (matches vLLM's gating)
- propose: draft-perspective predicate `draft_uses_mha = hasattr(runner,
  "eagle3_draft_builder")` drives the metadata-flow special cases
  (slot_mapping re-slice, context_lens += 1, tuple-unpack of the draft
  return value); the is_eagle3 string comparison is gone from the hot path

Result on Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, 8x MI355X, gsm8k 5-shot:
acceptance 67.85%, accuracy 93.78%.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
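For reference, a hedged sketch of the wiring that commit describes (class and argument details are assumptions; only the hook names and the eagle3_draft_builder attribute appear in the commit text):

class Eagle3DraftBuilder(AttentionMetadataBuilder):
    # the draft keeps its own non-MLA KV cache via the generic hooks
    def compute_block_bytes(self) -> int: ...
    def allocate_kv_cache_tensors(self, num_kv_heads, num_draft_layers): ...
    def build_kv_cache_tensor(self, layer_id, module): ...

class EagleProposer:
    def __init__(self, runner, draft_config):
        # attaching the builder is all the runner needs; its generic KV path
        # then sizes, allocates, and binds the draft cache
        runner.eagle3_draft_builder = Eagle3DraftBuilder()

# propose() keys its special cases off the draft's own structure, not a name check:
# draft_uses_mha = hasattr(runner, "eagle3_draft_builder")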