Support partitioned Metal attention #181
Conversation
75f3afb to
abf0e92
Compare
|
Thanks! @Kingwl Please also pick a dataset you like (e.g., a long-context dataset) and run `vllm bench serve` to add end-to-end benchmark results. See the example here: #136 (I should probably add it to the docs somewhere ...) |
|
@Kingwl Please check the failed unit test. Is it just numerical non-determinism or algorithmic incorrectness? BTW, could you please join the vLLM Slack: https://communityinviter.com/apps/vllm-dev/join-vllm-developers-slack ? I'm in a busy week, but I don't want to block your work. We can have a quick sync on Slack (text or call) if you have any questions or new issues you want to work on. |
abf0e92 to
76b7ef7
Compare
|
Thanks for your comment. That's fine, we can work asynchronously. I also need some time to figure out what I should do (I'm not familiar enough with the codebase yet). I will join the Slack later. |
c897f33 to
56e62e1
Compare
Parameters — device: M4 Max 64G. Current branch vs. main branch: on the ShareGPT serve benchmark with the same settings as PR #136, the current branch is consistently slower than main.
56e62e1 to
ea07747
Compare
|
This PR is blocked by a prefix cache error. I will update after the prefix cache issue has been fixed. |
#187 has been merged! |
ea07747 to
bd82f97
Compare
|
After rebasing onto main — benchmark script result and E2E benchmark result: I reran the end-to-end serving benchmark on both the current branch and origin/main with the same sonnet workload, using num-prompts=200.
Results: |
|
@Kingwl Thanks for the new benchmark results! Your initial result shows improvement on ultra-long context,
but the e2e test is only on 1024+128, so I'm a bit hesitant to draw any conclusions. More background: the partition code was originally vendored from mistral.rs and later co-evolved with the varlen + online softmax kernel, but it was never tested by me. I'm also not sure how upstream vLLM kernels (e.g., unified triton, FlashAttention, FlashInfer) handle partitions. But ultimately, we only have three logical options
I think the current benchmark results are a good foundation. Could you please provide more guidance on the decision-making? |
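For context on the partition mechanics under discussion: each partition computes a softmax-normalized partial output together with its running max logit and exp-sum, and a reduce pass merges the partials exactly. A minimal pure-Python sketch of that reduce (illustrative names, not the kernel's actual interface):

```python
import math

def reduce_partitions(partial_out, max_logits, exp_sums):
    """Combine per-partition online-softmax partials into the final output.

    partial_out[p] : value vector, softmax-normalized within partition p
    max_logits[p]  : running max logit of partition p
    exp_sums[p]    : sum of exp(logit - max_logits[p]) over partition p
    """
    global_max = max(max_logits)
    # Rescale each partition's weight mass to the shared global max.
    rescale = [math.exp(m - global_max) * s for m, s in zip(max_logits, exp_sums)]
    total = sum(rescale)
    dim = len(partial_out[0])
    return [sum(r * o[d] for r, o in zip(rescale, partial_out)) / total
            for d in range(dim)]
```

Because the rescaling is exact (not an approximation), any partition count gives the same result as unpartitioned softmax attention, up to floating-point rounding — which is why a failing unit test here is more likely numerical tolerance than algorithmic error.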
|
I'm trying to run bigger inputs (2048+128, 4096+128). It seems noticeably slower. |
bd82f97 to
d1b1f76
Compare
|
Use the partition kernel for long decode only. I ran two cases for long decode: 1024 + 4096, and the more decode-heavy 512 + 8192. Because of local hardware limits, it is difficult to run much larger-scale benchmarks reliably on this machine.

1024 + 4096
|
| metric | current branch (decode-only, threshold>=4096) | origin/main | vs main |
|---|---|---|---|
| Benchmark duration (s) | 413.10 | 389.98 | slower 5.9% |
| Output token throughput (tok/s) | 99.15 | 105.03 | slower 5.6% |
| Total token throughput (tok/s) | 123.58 | 130.90 | slower 5.6% |
| Mean TTFT (ms) | 567.25 | 526.10 | slower 7.8% |
| Median TTFT (ms) | 640.46 | 564.25 | slower 13.5% |
| P99 TTFT (ms) | 748.08 | 715.48 | slower 4.6% |
| Mean TPOT (ms) | 33.70 | 31.81 | slower 5.9% |
| Median TPOT (ms) | 32.91 | 31.53 | slower 4.4% |
| P99 TPOT (ms) | 35.28 | 32.79 | slower 7.6% |
512 + 8192 (num-prompts=5, max-concurrency=2, request-rate=5, --ignore-eos)
| metric | current branch (decode-only, threshold>=4096) | origin/main | vs main |
|---|---|---|---|
| Benchmark duration (s) | 670.75 | 714.49 | faster 6.1% |
| Output token throughput (tok/s) | 61.07 | 57.33 | faster 6.5% |
| Total token throughput (tok/s) | 64.81 | 60.85 | faster 6.5% |
| Mean TTFT (ms) | 203.95 | 184.00 | slower 10.8% |
| Median TTFT (ms) | 206.21 | 157.07 | slower 31.3% |
| P99 TTFT (ms) | 354.77 | 321.73 | slower 10.3% |
| Mean TPOT (ms) | 28.65 | 30.42 | faster 5.8% |
| Median TPOT (ms) | 29.21 | 30.98 | faster 5.7% |
| P99 TPOT (ms) | 32.25 | 34.03 | faster 5.2% |
|
Review notes: the test tolerance condition should use `PARTITION_THRESHOLD`, not `PARTITION_SIZE`; there is a missing sinks buffer binding (latent); and there is a duplicate constant definition. |
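The first point matters because the two constants answer different questions: how large each partition is versus when partitioning should kick in at all. A hedged sketch with hypothetical names and illustrative values (not the PR's actual code):

```python
# Illustrative constants; the actual values live in the kernel/dispatch code.
PARTITION_SIZE = 512        # tokens processed per partition by the _ps512 kernel
PARTITION_THRESHOLD = 4096  # minimum kv length before partitioning pays off

def use_partitioned_kernel(max_kv_len: int, is_decode: bool) -> bool:
    """Decide whether to dispatch the partitioned kernel.

    The selection (and any test tolerance tied to it) must compare against
    PARTITION_THRESHOLD; comparing against PARTITION_SIZE would enable the
    partitioned path for almost every sequence longer than one partition.
    """
    return is_decode and max_kv_len >= PARTITION_THRESHOLD
```

A test that branched on `max_kv_len >= PARTITION_SIZE` would apply the looser partitioned-path tolerance to sequences the runtime actually serves with the unpartitioned kernel.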
|
conflicts |
|
Update soon |
d1b1f76 to
c1c939d
Compare
Signed-off-by: kingwl <kingwenlu@gmail.com>
c1c939d to
3d6d683
Compare
|
Oh, thanks. But I'm working on more benchmarks on partitioning. |
…225) <img width="2493" height="927" alt="bench_primitive_comparison" src="https://github.com/user-attachments/assets/54a5f038-f16c-4bbe-a96d-343dcfae04fa" />

Builds on the spike in #209. Related: #188.

### What

Replace the eager `reshape_and_cache` + `paged_attention_v2_online` dispatch with:

1. **MLX-native scatter** for cache writes - pure functional, graph-tracked, donation-eligible. Replaces the custom `reshape_and_cache` Metal kernel in the production path.
2. **`PagedAttentionPrimitive`** for attention - a read-only primitive that dispatches the paged attention Metal kernel lazily.

Both operations are fully lazy. No per-layer `mx.eval` or `mx.synchronize`. The entire 28-layer model builds one lazy graph, evaluated once by the model runner.

**Bug fix**: custom primitives must not call `add_temporary` inside `eval_gpu`. MLX's `add_temporary` removes buffer pointers from the command encoder's fence tracking, breaking cross-command-buffer synchronization when the graph is evaluated lazily. The fix: `from_primitive=true` skips all `add_temporary` calls. MLX's evaluator already manages array lifetimes via the completion handler. This matches the pattern in MLX's official `axpby` extension example.

### Why

The original eager path does **1 `mx.eval` + 1 `mx.synchronize`** per layer, each a CPU-GPU sync point. With 28 layers, that is 56 sync points per decode step. The primitive path eliminates all of them. The scatter participates in MLX's lazy graph, and the attention primitive dispatches correctly across command buffer boundaries thanks to the `add_temporary` fix.

### Test

6/6 deterministic golden tests (bit-exact). 362/362 broader suite.

### Future Work

- Wire partitioned attention (#181) into the primitive path
- Clean up dead eager code paths (`reshape_and_cache`, `metal_unified_attention`, etc.)

### Benchmark (sonnet 1024+128, 100 prompts, concurrency 8, 5 warmups)

| Metric | main | this PR | Change |
|---|---:|---:|---:|
| Duration (s) | 288.71 | 257.01 | **-11.0%** |
| Output tok/s | 44.34 | 49.80 | **+12.3%** |
| Total tok/s | 394.96 | 443.68 | **+12.3%** |
| Mean TTFT (ms) | 4810.26 | 4771.41 | -0.8% |
| Mean TPOT (ms) | 139.98 | 121.47 | **-13.2%** |
| P99 TPOT (ms) | 169.16 | 149.75 | **-11.5%** |

<details>
<summary>Full benchmark output</summary>

**main (paged attention, eager dispatch)**

```
============ Serving Benchmark Result ============
Successful requests: 100
Failed requests: 0
Maximum request concurrency: 8
Request rate configured (RPS): 10.00
Benchmark duration (s): 288.71
Total input tokens: 101230
Total generated tokens: 12800
Request throughput (req/s): 0.35
Output token throughput (tok/s): 44.34
Peak output token throughput (tok/s): 112.00
Peak concurrent requests: 11.00
Total token throughput (tok/s): 394.96
---------------Time to First Token----------------
Mean TTFT (ms): 4810.26
Median TTFT (ms): 4865.42
P99 TTFT (ms): 10145.45
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 139.98
Median TPOT (ms): 139.36
P99 TPOT (ms): 169.16
---------------Inter-token Latency----------------
Mean ITL (ms): 139.98
Median ITL (ms): 79.76
P99 ITL (ms): 2661.48
==================================================
```

**this PR (MLX scatter + paged attention primitive, fully lazy)**

```
============ Serving Benchmark Result ============
Successful requests: 100
Failed requests: 0
Maximum request concurrency: 8
Request rate configured (RPS): 10.00
Benchmark duration (s): 257.01
Total input tokens: 101230
Total generated tokens: 12800
Request throughput (req/s): 0.39
Output token throughput (tok/s): 49.80
Peak output token throughput (tok/s): 120.00
Peak concurrent requests: 11.00
Total token throughput (tok/s): 443.68
---------------Time to First Token----------------
Mean TTFT (ms): 4771.41
Median TTFT (ms): 5165.51
P99 TTFT (ms): 9396.19
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 121.47
Median TPOT (ms): 118.78
P99 TPOT (ms): 149.75
---------------Inter-token Latency----------------
Mean ITL (ms): 121.47
Median ITL (ms): 70.34
P99 ITL (ms): 2660.84
==================================================
```

</details>

<details>
<summary>Benchmark config</summary>

- Model: Qwen/Qwen3-0.6B
- Dataset: sonnet (1024 input + 128 output)
- Prompts: 100, rate: 10, concurrency: 8, warmups: 5
- Memory fraction: 0.3 (paged path)
- Hardware: Apple M1 Pro, 32 GB RAM

</details>

---------

Signed-off-by: ran <hzz5361@psu.edu>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
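To illustrate what a scatter-based cache write computes, here is a pure-Python sketch of generic paged-KV `slot_mapping` semantics (names and `BLOCK_SIZE` are illustrative; this is not the MLX code in the commit, which expresses the same write as a lazy graph operation):

```python
# Illustrative block size; real paged caches pick this to match the kernel.
BLOCK_SIZE = 16

def scatter_keys(key_cache, new_keys, slot_mapping):
    """Write each new token's key into its assigned cache slot.

    key_cache    : list of blocks, each a list of BLOCK_SIZE slots
    new_keys     : one key per new token
    slot_mapping : flat slot index per token (block_idx * BLOCK_SIZE + offset)
    """
    for key, slot in zip(new_keys, slot_mapping):
        block_idx, offset = divmod(slot, BLOCK_SIZE)
        key_cache[block_idx][offset] = key
    return key_cache
```

Expressing this write as a native scatter lets the framework track it in the lazy graph instead of forcing an eager kernel dispatch per layer.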
This PR supports partitioned v2 Metal attention for long-context workloads. It adds runtime selection between the existing no-partition path and the `_ps512` partitioned kernel, allocates the required scratch buffers on the Python side, and dispatches the reduce kernel when partitioning is enabled.
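As a rough sketch of what the Python-side scratch allocation involves for a paged-attention-v2-style partitioned kernel (the shape layout and names here follow the usual v2 design and are assumptions, not this PR's exact code):

```python
import math

PARTITION_SIZE = 512  # matches the _ps512 kernel variant

def scratch_shapes(num_seqs, num_heads, head_dim, max_kv_len):
    """Shapes of the per-partition scratch buffers consumed by the reduce kernel.

    Each partition writes its running max logit, exp-sum, and partial output;
    the reduce kernel then combines them into the final attention output.
    """
    num_partitions = math.ceil(max_kv_len / PARTITION_SIZE)
    return {
        "max_logits": (num_seqs, num_heads, num_partitions),
        "exp_sums": (num_seqs, num_heads, num_partitions),
        "tmp_out": (num_seqs, num_heads, num_partitions, head_dim),
    }
```

For example, a decode batch of 2 sequences, 8 heads, head_dim 128, and an 8192-token KV cache would need 16 partitions per sequence.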
benchmark script: https://gist.github.com/Kingwl/ec70556729956a7e359b64a024820961
On the local benchmarks I ran, decode-long (kv=8192) improved by about 30%, and decode-long-16384 improved by about 48%.
`num_threads` and `partition` are not tuned yet. Fixed #180