Fuse `reshape_and_cache` + `paged_attention` into a single MLX primitive #225
Conversation
Signed-off-by: ran <hzz5361@psu.edu>
Design: functional cache writes via MLX scatter. From a functional programming semantics perspective, this is exactly what the original
This is the last piece of the paged attention puzzle! Now there is only one MLX lazy graph, with no unnecessary syncs.
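A toy sketch (plain Python, not MLX's API) of the sync-point accounting this enables: the eager path pays an `mx.eval` plus an `mx.synchronize` per layer, while the lazy path chains the whole graph and forces it once.

```python
# Toy model of eager-per-layer vs. fully lazy evaluation.
# Names and structure are illustrative only, not MLX's API.

NUM_LAYERS = 28

def layer(x):
    return x + 1  # stand-in for one transformer layer

# Eager path: one mx.eval + one mx.synchronize per layer.
eager_sync_points = NUM_LAYERS * 2  # 56 CPU-GPU sync points per decode step

# Lazy path: chain thunks for all layers, force the graph once at the end.
graph = lambda: 0
for _ in range(NUM_LAYERS):
    graph = (lambda prev: (lambda: layer(prev())))(graph)

lazy_sync_points = 1      # a single eval by the model runner
result = graph()          # forces the entire 28-layer chain at once
```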
Cache rebind may break shared references (prefix caching); no test coverage for the primitive path; `mx.array(0)` + `overwrite_descriptor` is fragile; `UnaryPrimitive` with 6 inputs is semantically misleading
This is from MLX: `UnaryPrimitive` means single-output, not single-input.
Agreed, it is indeed fragile, but I don't have a better way for now. The new test should cover it if an upstream MLX update breaks it.
Prefix caching is applied at the block-index level, not the array-reference level, so this should be fine.
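A minimal sketch of the point above, with plain Python lists standing in for the KV cache (hypothetical structures, not the actual code): sequences share a prefix through block indices, so rebinding the cache array leaves the sharing intact.

```python
# Hypothetical paged KV cache: sequences reference blocks by index,
# never by holding a direct reference to the cache array itself.
BLOCK_SIZE = 4
kv_cache = [[0.0] * BLOCK_SIZE for _ in range(8)]  # 8 physical blocks

# Two sequences share a prefix via the same block indices (0 and 1).
block_table_a = [0, 1, 2]
block_table_b = [0, 1, 3]

kv_cache[0] = [1.0] * BLOCK_SIZE  # write the shared prefix block

# A functional cache write produces a new array (as an MLX scatter
# would) and rebinds the name; block indices remain valid throughout.
new_cache = [row[:] for row in kv_cache]
new_cache[2] = [2.0] * BLOCK_SIZE  # seq A's private block
kv_cache = new_cache               # rebind

# Both sequences still see the shared prefix through index 0.
prefix_a = kv_cache[block_table_a[0]]
prefix_b = kv_cache[block_table_b[0]]
```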
Resolve conflict in paged_ops.cpp: keep both paged_attention_primitive (ours) and gdn_linear_attention (upstream vllm-project#226) bindings. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: ran <hzz5361@psu.edu>
@ericcurtin added some tests. Requesting review.
…llm-project#225)

<img width="2493" height="927" alt="bench_primitive_comparison" src="https://github.com/user-attachments/assets/54a5f038-f16c-4bbe-a96d-343dcfae04fa" />

Builds on the spike in vllm-project#209. Related: vllm-project#188.

### What

Replace the eager `reshape_and_cache` + `paged_attention_v2_online` dispatch with:

1. **MLX-native scatter** for cache writes - pure functional, graph-tracked, donation-eligible. Replaces the custom `reshape_and_cache` Metal kernel in the production path.
2. **`PagedAttentionPrimitive`** for attention - a read-only primitive that dispatches the paged attention Metal kernel lazily.

Both operations are fully lazy. No per-layer `mx.eval` or `mx.synchronize`. The entire 28-layer model builds one lazy graph, evaluated once by the model runner.

**Bug fix**: custom primitives must not call `add_temporary` inside `eval_gpu`. MLX's `add_temporary` removes buffer pointers from the command encoder's fence tracking, breaking cross-command-buffer synchronization when the graph is evaluated lazily. The fix: `from_primitive=true` skips all `add_temporary` calls. MLX's evaluator already manages array lifetimes via the completion handler. This matches the pattern in MLX's official `axpby` extension example.

### Why

The original eager path does **1 `mx.eval` + 1 `mx.synchronize`** per layer, each a CPU-GPU sync point. With 28 layers, that is 56 sync points per decode step.

The primitive path eliminates all of them. The scatter participates in MLX's lazy graph, and the attention primitive dispatches correctly across command buffer boundaries thanks to the `add_temporary` fix.

### Test

6/6 deterministic golden tests (bit-exact). 362/362 broader suite.

### Future Work

- Wire partitioned attention (vllm-project#181) into the primitive path
- Clean up dead eager code paths (`reshape_and_cache`, `metal_unified_attention`, etc.)
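A hedged sketch of the scatter-style cache write, with nested Python lists standing in for `mx.array` (slot layout, shapes, and names are assumptions for illustration, not the PR's exact code): each new token's K/V vector lands in the flat cache slot named by `slot_mapping`, and the write is purely functional.

```python
# reshape_and_cache expressed as a pure scatter: each new token's K/V
# vector is written to the flat cache slot given by slot_mapping.
# Plain-Python stand-in for an MLX scatter; shapes and names illustrative.

BLOCK_SIZE = 2
NUM_BLOCKS = 4
HEAD_SIZE = 3

def scatter_cache(cache, slot_mapping, new_kv):
    """Functional update: returns a new cache, input left untouched."""
    out = [row[:] for row in cache]
    for token_idx, slot in enumerate(slot_mapping):
        out[slot] = new_kv[token_idx][:]
    return out

cache = [[0.0] * HEAD_SIZE for _ in range(NUM_BLOCKS * BLOCK_SIZE)]
slot_mapping = [5, 2]  # token 0 -> slot 5, token 1 -> slot 2
new_kv = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]

updated = scatter_cache(cache, slot_mapping, new_kv)
```

Because the update returns a fresh array instead of mutating in place, it can participate in a lazy graph and remain donation-eligible, which is the property the PR relies on.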
### Benchmark (sonnet 1024+128, 100 prompts, concurrency 8, 5 warmups)

| Metric | main | this PR | Change |
|---|---:|---:|---:|
| Duration (s) | 288.71 | 257.01 | **-11.0%** |
| Output tok/s | 44.34 | 49.80 | **+12.3%** |
| Total tok/s | 394.96 | 443.68 | **+12.3%** |
| Mean TTFT (ms) | 4810.26 | 4771.41 | -0.8% |
| Mean TPOT (ms) | 139.98 | 121.47 | **-13.2%** |
| P99 TPOT (ms) | 169.16 | 149.75 | **-11.5%** |

<details>
<summary>Full benchmark output</summary>

**main (paged attention, eager dispatch)**

```
============ Serving Benchmark Result ============
Successful requests: 100
Failed requests: 0
Maximum request concurrency: 8
Request rate configured (RPS): 10.00
Benchmark duration (s): 288.71
Total input tokens: 101230
Total generated tokens: 12800
Request throughput (req/s): 0.35
Output token throughput (tok/s): 44.34
Peak output token throughput (tok/s): 112.00
Peak concurrent requests: 11.00
Total token throughput (tok/s): 394.96
---------------Time to First Token----------------
Mean TTFT (ms): 4810.26
Median TTFT (ms): 4865.42
P99 TTFT (ms): 10145.45
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 139.98
Median TPOT (ms): 139.36
P99 TPOT (ms): 169.16
---------------Inter-token Latency----------------
Mean ITL (ms): 139.98
Median ITL (ms): 79.76
P99 ITL (ms): 2661.48
==================================================
```

**this PR (MLX scatter + paged attention primitive, fully lazy)**

```
============ Serving Benchmark Result ============
Successful requests: 100
Failed requests: 0
Maximum request concurrency: 8
Request rate configured (RPS): 10.00
Benchmark duration (s): 257.01
Total input tokens: 101230
Total generated tokens: 12800
Request throughput (req/s): 0.39
Output token throughput (tok/s): 49.80
Peak output token throughput (tok/s): 120.00
Peak concurrent requests: 11.00
Total token throughput (tok/s): 443.68
---------------Time to First Token----------------
Mean TTFT (ms): 4771.41
Median TTFT (ms): 5165.51
P99 TTFT (ms): 9396.19
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 121.47
Median TPOT (ms): 118.78
P99 TPOT (ms): 149.75
---------------Inter-token Latency----------------
Mean ITL (ms): 121.47
Median ITL (ms): 70.34
P99 ITL (ms): 2660.84
==================================================
```

</details>

<details>
<summary>Benchmark config</summary>

- Model: Qwen/Qwen3-0.6B
- Dataset: sonnet (1024 input + 128 output)
- Prompts: 100, rate: 10, concurrency: 8, warmups: 5
- Memory fraction: 0.3 (paged path)
- Hardware: Apple M1 Pro, 32 GB RAM

</details>

---------

Signed-off-by: ran <hzz5361@psu.edu>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
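The percentage deltas in the table can be reproduced from the raw numbers:

```python
def pct_change(baseline, new):
    """Relative change of `new` vs. `baseline`, in percent."""
    return (new - baseline) / baseline * 100.0

duration = pct_change(288.71, 257.01)  # about -11.0%
out_tps  = pct_change(44.34, 49.80)    # about +12.3%
tpot     = pct_change(139.98, 121.47)  # about -13.2%
```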