
Fuse reshape_and_cache + paged_attention into a single MLX primitive#225

Merged
ericcurtin merged 8 commits into vllm-project:main from WindChimeRan:feat/fused-reshape-attention-primitive
Apr 6, 2026
Conversation

@WindChimeRan
Collaborator

@WindChimeRan WindChimeRan commented Apr 3, 2026

(Figure: bench_primitive_comparison benchmark chart)

Builds on the spike in #209. Related: #188.

### What

Replace the eager `reshape_and_cache` + `paged_attention_v2_online` dispatch with:

1. **MLX-native scatter** for cache writes - pure functional, graph-tracked, donation-eligible. Replaces the custom `reshape_and_cache` Metal kernel in the production path.
2. **`PagedAttentionPrimitive`** for attention - a read-only primitive that dispatches the paged attention Metal kernel lazily.

Both operations are fully lazy. No per-layer `mx.eval` or `mx.synchronize`. The entire 28-layer model builds one lazy graph, evaluated once by the model runner.

**Bug fix**: custom primitives must not call `add_temporary` inside `eval_gpu`. MLX's `add_temporary` removes buffer pointers from the command encoder's fence tracking, breaking cross-command-buffer synchronization when the graph is evaluated lazily. The fix: `from_primitive=true` skips all `add_temporary` calls. MLX's evaluator already manages array lifetimes via the completion handler. This matches the pattern in MLX's official `axpby` extension example.

### Why

The original eager path does **1 `mx.eval` + 1 `mx.synchronize`** per layer, each a CPU-GPU sync point. With 28 layers, that is 56 sync points per decode step.

The primitive path eliminates all of them. The scatter participates in MLX's lazy graph, and the attention primitive dispatches correctly across command buffer boundaries thanks to the `add_temporary` fix.
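The sync-point accounting above can be sketched in plain Python (a toy model, not MLX; `decode_step_eager` and `decode_step_lazy` are illustrative names, not functions from this PR):

```python
# "sync" stands in for one host<->GPU round trip (an mx.eval or an
# mx.synchronize). The eager path pays two per layer; the lazy path
# only records graph nodes and pays one evaluation per decode step.

def decode_step_eager(num_layers: int) -> int:
    syncs = 0
    for _ in range(num_layers):
        syncs += 1  # mx.eval after the in-place cache write
        syncs += 1  # mx.synchronize before the next layer
    return syncs

def decode_step_lazy(num_layers: int) -> int:
    graph = [lambda l=layer: f"layer-{l}" for layer in range(num_layers)]
    for thunk in graph:   # the model runner evaluates the whole graph...
        thunk()
    return 1              # ...in a single evaluation, the only sync point

assert decode_step_eager(28) == 56  # matches the 56 sync points above
assert decode_step_lazy(28) == 1
```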

### Test

6/6 deterministic golden tests (bit-exact). 362/362 broader suite.

### Future Work

- Wire partitioned attention (#181) into the primitive path
- Clean up dead eager code paths (`reshape_and_cache`, `metal_unified_attention`, etc.)

### Benchmark (sonnet 1024+128, 100 prompts, concurrency 8, 5 warmups)

| Metric | main | this PR | Change |
|---|---:|---:|---:|
| Duration (s) | 288.71 | 257.01 | **-11.0%** |
| Output tok/s | 44.34 | 49.80 | **+12.3%** |
| Total tok/s | 394.96 | 443.68 | **+12.3%** |
| Mean TTFT (ms) | 4810.26 | 4771.41 | -0.8% |
| Mean TPOT (ms) | 139.98 | 121.47 | **-13.2%** |
| P99 TPOT (ms) | 169.16 | 149.75 | **-11.5%** |
Full benchmark output

**main (paged attention, eager dispatch)**

```
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             8
Request rate configured (RPS):           10.00
Benchmark duration (s):                  288.71
Total input tokens:                      101230
Total generated tokens:                  12800
Request throughput (req/s):              0.35
Output token throughput (tok/s):         44.34
Peak output token throughput (tok/s):    112.00
Peak concurrent requests:                11.00
Total token throughput (tok/s):          394.96
---------------Time to First Token----------------
Mean TTFT (ms):                          4810.26
Median TTFT (ms):                        4865.42
P99 TTFT (ms):                           10145.45
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          139.98
Median TPOT (ms):                        139.36
P99 TPOT (ms):                           169.16
---------------Inter-token Latency----------------
Mean ITL (ms):                           139.98
Median ITL (ms):                         79.76
P99 ITL (ms):                            2661.48
==================================================
```

**this PR (MLX scatter + paged attention primitive, fully lazy)**

```
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             8
Request rate configured (RPS):           10.00
Benchmark duration (s):                  257.01
Total input tokens:                      101230
Total generated tokens:                  12800
Request throughput (req/s):              0.39
Output token throughput (tok/s):         49.80
Peak output token throughput (tok/s):    120.00
Peak concurrent requests:                11.00
Total token throughput (tok/s):          443.68
---------------Time to First Token----------------
Mean TTFT (ms):                          4771.41
Median TTFT (ms):                        5165.51
P99 TTFT (ms):                           9396.19
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          121.47
Median TPOT (ms):                        118.78
P99 TPOT (ms):                           149.75
---------------Inter-token Latency----------------
Mean ITL (ms):                           121.47
Median ITL (ms):                         70.34
P99 ITL (ms):                            2660.84
==================================================
```
Benchmark config

- Model: Qwen/Qwen3-0.6B
- Dataset: sonnet (1024 input + 128 output)
- Prompts: 100, rate: 10, concurrency: 8, warmups: 5
- Memory fraction: 0.3 (paged path)
- Hardware: Apple M1 Pro, 32 GB RAM

Signed-off-by: ran <hzz5361@psu.edu>
@WindChimeRan WindChimeRan force-pushed the feat/fused-reshape-attention-primitive branch from 96bd651 to 7a2ec93 Compare April 3, 2026 04:45
@WindChimeRan WindChimeRan changed the title Fused Page Attention Primitive to save CPU-GPU sync Fuse reshape_and_cache + paged_attention into a single MLX primitive Apr 3, 2026
@WindChimeRan
Collaborator Author

WindChimeRan commented Apr 3, 2026

Design: functional cache writes via MLX scatter

From a functional programming semantics perspective, `reshape_and_cache` doesn't mutate the "whole cache." It writes to specific slots determined by `slot_mapping`. Those slots (`block_idx`, `block_offset`) are disjoint from all other slots - no two tokens write to the same slot. Conceptually:

```python
new_cache = old_cache                       # copy the whole thing
for i in range(len(slot_mapping)):
    new_cache[slot_mapping[i]] = new_kv[i]  # write to specific slots
```

This is exactly what `mx.scatter` or slice assignment does - a pure functional operation. The result is a "new" array. MLX knows how to handle this: when the old cache reference has `use_count == 1` (nobody else holds it), MLX can donate the buffer - the "copy" reuses the same physical memory. Zero allocation, zero memcpy.

The original `reshape_and_cache` Metal kernel was an optimization that mutated the cache buffer in place. This made it a side effect invisible to MLX's computation graph, forcing per-layer `mx.eval` + `mx.synchronize` to ensure correctness. By replacing it with MLX's native scatter, the cache write becomes a proper graph node. MLX tracks the dependency, handles buffer donation, and sequences it correctly across command buffer boundaries - no explicit sync needed.
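The disjoint-slot argument can be checked with a minimal NumPy sketch (NumPy stands in for `mx.scatter`; the shapes and `slot_mapping` values are made up for illustration):

```python
import numpy as np

# A flat KV cache of num_slots slots, each holding one token's K or V.
num_slots, head_dim = 16, 4
old_cache = np.zeros((num_slots, head_dim), dtype=np.float32)

# Three new tokens land in the disjoint slots chosen by slot_mapping.
slot_mapping = np.array([5, 2, 11])
new_kv = np.arange(3 * head_dim, dtype=np.float32).reshape(3, head_dim)

# Functional write: copy the cache, then scatter into the selected slots.
# (In MLX the copy is elided by buffer donation when use_count == 1.)
new_cache = old_cache.copy()
new_cache[slot_mapping] = new_kv

# An in-place kernel would produce exactly the same contents.
old_cache[slot_mapping] = new_kv
assert np.array_equal(new_cache, old_cache)

# All slots outside slot_mapping are untouched.
untouched = np.delete(np.arange(num_slots), slot_mapping)
assert not new_cache[untouched].any()
```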

@WindChimeRan WindChimeRan marked this pull request as ready for review April 3, 2026 18:11
@WindChimeRan
Collaborator Author

WindChimeRan commented Apr 3, 2026

@ericcurtin @Kingwl

This is the last piece of the paged attention puzzle! Everything now lives in a single MLX lazy graph, with no unnecessary syncs.

@ericcurtin
Collaborator

- Cache rebind may break shared references (prefix caching)
- No test coverage for the primitive path
- `mx.array(0)` + `overwrite_descriptor` is fragile
- `UnaryPrimitive` with 6 inputs is semantically misleading

@WindChimeRan WindChimeRan marked this pull request as draft April 6, 2026 05:23
@WindChimeRan
Collaborator Author

@ericcurtin

> `UnaryPrimitive` with 6 inputs is semantically misleading

This comes from MLX itself: `UnaryPrimitive` means single-output, not single-input.

@WindChimeRan
Collaborator Author

> `mx.array(0)` + `overwrite_descriptor` is fragile

Agreed, it is fragile, but I don't have a better approach for now. The new test should catch it if an upstream MLX update breaks it.

@WindChimeRan
Collaborator Author

> Cache rebind may break shared references (prefix caching)

Prefix caching is applied at the block-index level, not the array-reference level, so this should be fine.
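That distinction can be illustrated with a tiny sketch (hypothetical structures, not the actual scheduler code):

```python
# Block tables store *indices* into the paged cache, and prefix sharing
# happens at that index level. Rebinding the cache array to a new
# (donated) buffer leaves the tables untouched, so sharing survives.
block_tables = {
    "seq_a": [0, 1, 2],  # blocks 0 and 1 hold the shared prefix
    "seq_b": [0, 1, 3],
}

cache_v1 = ["blk0", "blk1", "blk2", "blk3"]
cache_v2 = list(cache_v1)  # "rebound" cache after a functional write

shared = sorted(set(block_tables["seq_a"]) & set(block_tables["seq_b"]))
assert shared == [0, 1]
# Both cache bindings resolve the shared blocks to the same contents.
assert [cache_v2[i] for i in shared] == [cache_v1[i] for i in shared]
```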

@WindChimeRan WindChimeRan force-pushed the feat/fused-reshape-attention-primitive branch 2 times, most recently from 3416d33 to fc94266 Compare April 6, 2026 06:50
WindChimeRan and others added 3 commits April 6, 2026 01:53
Resolve conflict in paged_ops.cpp: keep both paged_attention_primitive
(ours) and gdn_linear_attention (upstream vllm-project#226) bindings.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: ran <hzz5361@psu.edu>
@WindChimeRan WindChimeRan force-pushed the feat/fused-reshape-attention-primitive branch from fc94266 to 2f691eb Compare April 6, 2026 06:54
@WindChimeRan WindChimeRan marked this pull request as ready for review April 6, 2026 06:56
@WindChimeRan
Collaborator Author

@ericcurtin added some tests. Requesting review.

@ericcurtin ericcurtin merged commit f518143 into vllm-project:main Apr 6, 2026
5 checks passed
Alex-ai-future pushed a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026