
Refactor paged attention dispatch to support multiple attention types #201

Merged
LxYuan0420 merged 7 commits into vllm-project:main from WindChimeRan:attention_backend_dispatch
Mar 24, 2026
Conversation

Collaborator

@WindChimeRan WindChimeRan commented Mar 22, 2026

Summary

  • Extract attention-type-specific logic from the monolithic wrapper into per-type modules (attention_sdpa.py, attention_linear.py)
  • The wrapper now dispatches based on module attributes: the paged varlen SDPA path for standard dot-product attention (MHA/GQA/MQA) and a stub for linear attention (GatedDeltaNet); this will also unblock MLA for glm-4.7-flash and DeepSeek
  • Change layer patching from single-attribute (self_attn on all layers) to per-layer lookup, which is required for hybrid models like Qwen3.5 where some layers use self_attn and others use linear_attn (see the sketch after this summary)
  • Add an xfail integration test for Qwen/Qwen3.5-0.8B with paged attention

No new features. This is a refactor to unblock collaboration on Qwen3.5 linear attention support.
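For the per-layer lookup in particular, the idea is roughly the sketch below; the helper name `find_attention_modules` and the iteration details are illustrative assumptions, not this PR's actual code:

```python
# Illustrative per-layer lookup: hybrid models expose different attribute
# names (self_attn vs. linear_attn) on different layers, so each layer is
# probed individually instead of assuming self_attn everywhere.
def find_attention_modules(model):
    for idx, layer in enumerate(model.layers):
        for attr in ("self_attn", "linear_attn"):
            module = getattr(layer, attr, None)
            if module is not None:
                yield idx, attr, module
```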

Why attribute-based detection?

We wrap mlx_lm/mlx_vlm attention modules at runtime without modifying their source. Since we don't own the model code, we detect the attention type by probing module attributes (e.g. q_proj + o_proj → SDPA; conv1d and no q_proj → linear). This should work across all known GatedDeltaNet variants (Qwen3.5).
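A minimal sketch of the two probes, assuming hypothetical detector names rather than the exact identifiers in attention_sdpa.py / attention_linear.py:

```python
# Sketch of the attribute probes described above; names are illustrative.
def is_sdpa_attention(module) -> bool:
    # Standard dot-product attention (MHA/GQA/MQA) exposes q/o projections.
    return hasattr(module, "q_proj") and hasattr(module, "o_proj")

def is_linear_attention(module) -> bool:
    # GatedDeltaNet-style linear attention has a conv1d but no q_proj.
    return hasattr(module, "conv1d") and not hasattr(module, "q_proj")
```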

How to add a new attention type

  1. Create attention_<type>.py with is_<type>() detector + <type>_forward() implementation
  2. Add one elif branch in the wrapper dispatch (paged_attention.py)
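Step 2's branch could look like the following sketch (the wrapper shape and the `*_forward` signatures are assumptions based on the convention above; the real forward presumably threads cache and metadata arguments not shown here):

```python
# Hypothetical dispatch in the wrapper (paged_attention.py), using the
# detectors from the sketch above.
def dispatch(module, *args, **kwargs):
    if is_sdpa_attention(module):
        return sdpa_forward(module, *args, **kwargs)
    elif is_linear_attention(module):
        return linear_forward(module, *args, **kwargs)
    raise NotImplementedError(f"unsupported attention module: {type(module)}")
```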

Test plan

@WindChimeRan changed the title from "Attention backend dispatch" to "Refactor paged attention dispatch to support multiple attention types" on Mar 22, 2026
@WindChimeRan WindChimeRan marked this pull request as ready for review March 23, 2026 00:50
laudney pushed a commit to mmonad/vllm-metal that referenced this pull request Mar 23, 2026
Hybrid models like Qwen3.5 use mixed cache types (ArraysCache for
linear/SSM layers + KVCache for attention layers). BatchKVCache.offset
returns mx.array but hybrid attention code uses cache.offset as a
Python int for mask slicing, causing:

    ValueError: Slice indices must be integers or None.

Detect hybrid caches at model load time via make_prompt_cache() and
fall back to sequential decode for incompatible models.

Core detection logic lives in cache_utils.py to keep model_runner.py
minimal per vllm-project#122.

NOTE: This is an interim fix for the mlx-native (non-paged) path.
The proper solution is per-layer attention dispatching (vllm-project#201) plus a
paged linear attention kernel (roadmap vllm-project#148).

Signed-off-by: Bren Mada Bowen <bowen.bren@gmail.com>
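For illustration, the load-time probe described in that commit could be as simple as the sketch below; `has_hybrid_cache` is a hypothetical name, and only `make_prompt_cache` / `KVCache` are real mlx_lm APIs:

```python
from mlx_lm.models.cache import KVCache, make_prompt_cache

def has_hybrid_cache(model) -> bool:
    # Hybrid models return a mix of cache types: non-KVCache entries
    # (e.g. ArraysCache for linear/SSM layers) appear alongside the
    # KVCache entries used by attention layers.
    caches = make_prompt_cache(model)
    return any(not isinstance(c, KVCache) for c in caches)
```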
@ericcurtin
Collaborator

Minor: there's an unused import in paged_attention.py, and there are no unit tests for the dispatch logic (only the slow integration test).

@WindChimeRan
Collaborator Author

Thanks for the review! @ericcurtin

Added 4 fast tests (test_attention_dispatch.py) that verify the detection heuristics against real mlx_lm modules: qwen3.Attention, qwen3_5.DecoderLayer (both SDPA and GatedDeltaNet layers), and qwen3.Model for find_layers. No model weights are needed, and the suite runs in ~2 s. The full dispatch path (wrapper → forward → Metal kernel) is covered by the existing integration test.
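By way of illustration, such a detection test can stay weight-free; the miniature below uses lightweight stand-ins and the hypothetical detector names from the sketch in the PR description, whereas the actual suite asserts against real qwen3 / qwen3_5 modules:

```python
# Hypothetical miniature of a detection test; test_attention_dispatch.py
# exercises real mlx_lm modules instead of stand-ins.
from types import SimpleNamespace

def test_detection_heuristics():
    sdpa = SimpleNamespace(q_proj=object(), o_proj=object())
    gdn = SimpleNamespace(conv1d=object())
    assert is_sdpa_attention(sdpa)
    assert is_linear_attention(gdn)
    assert not is_linear_attention(sdpa)
```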

@LxYuan0420 LxYuan0420 merged commit 19a19c4 into vllm-project:main Mar 24, 2026
5 checks passed
LxYuan0420 added a commit that referenced this pull request Mar 27, 2026
#214)

This PR is:
- To remove `find_layers_and_attr` (deprecated in #201, zero callers)
- To delete `TestBatchSplitting` which tests a local reimplementation
with no production counterpart (`_run_packed_prefill` was removed in
5bf9536)

---------

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
LxYuan0420 pushed a commit that referenced this pull request Mar 30, 2026
…B) (#210)

## Summary

Allocate per-layer-type cache buffers for hybrid models (Qwen3.5) where
SDPA and GDN linear attention layers coexist. This is Stage B of the
Qwen3.5 roadmap (#194) and builds on the dispatch refactor (Stage A, #201).

- Unwrap `text_config` in `_extract_model_args` so Qwen3.5 dimensions
are accessible
- Add `is_hybrid` detection and GDN dimensions to `_resolve_model_dims`
- Emit `FullAttentionSpec` for SDPA layers and `MambaSpec` for GDN
layers in `get_kv_cache_spec`
- Fix `get_cache_block_size_bytes` to count only SDPA layers
- Add `LinearAttentionCache` with layout `[num_blocks, Hv, Dv, Dk]` per
linear layer
- Add `HybridPagedAttentionBackend` that allocates both
`MetalPagedKVCache` (SDPA) and `LinearAttentionCache` (GDN)
- Fail fast with `RuntimeError` when a hybrid model enables paged
attention (gated until Stage C)
- Only SDPA layers are patched; linear layers keep their original mlx_lm forward

Ref: #194 (Stage B: Hybrid cache allocation)

## Cache layout

| Layer type | Cache class | Shape per layer |
|---|---|---|
| SDPA | `MetalPagedKVCache` | `[num_blocks, block_size, num_kv_heads, head_dim]` |
| Linear (GDN) | `LinearAttentionCache` | `[num_blocks, Hv, Dv, Dk]` |

Both caches use the same `num_blocks` from the scheduler's memory
budget. `get_kv_cache_spec` emits `MambaSpec` for GDN layers so the
scheduler groups them separately.

This PR delivers allocation infrastructure to unblock Stage C kernel
work.

---------

Signed-off-by: RickyChen / 陳昭儒 <rickychen@infinirc.com>
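As a rough sketch of the dual allocation that commit describes (constructor signatures for `MetalPagedKVCache` and `LinearAttentionCache` are assumptions, not the PR's actual API):

```python
# Hypothetical sketch of HybridPagedAttentionBackend's allocation split;
# both caches draw num_blocks from the same scheduler memory budget.
class HybridPagedAttentionBackend:
    def __init__(self, num_blocks, block_size, sdpa_layers, gdn_layers,
                 num_kv_heads, head_dim, Hv, Dv, Dk):
        # Paged KV buffers for SDPA layers:
        # [num_blocks, block_size, num_kv_heads, head_dim] per layer.
        self.kv_cache = MetalPagedKVCache(
            num_blocks, block_size, num_kv_heads, head_dim,
            num_layers=len(sdpa_layers),
        )
        # Recurrent-state buffers for GDN layers:
        # [num_blocks, Hv, Dv, Dk] per layer.
        self.linear_cache = LinearAttentionCache(
            num_blocks, Hv, Dv, Dk, num_layers=len(gdn_layers),
        )
```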