
Support partitioned Metal attention#181

Merged
ericcurtin merged 1 commit into vllm-project:main from Kingwl:feat/support-partition-for-attention
Mar 25, 2026

Conversation

@Kingwl
Contributor

@Kingwl Kingwl commented Mar 19, 2026

This PR supports partitioned v2 Metal attention for long-context workloads. It adds runtime selection between the existing no-partition path and the _ps512 partitioned kernel, allocates the required scratch buffers on the Python side, and dispatches the reduce kernel when partitioning is enabled.
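The partitioned path is a split-KV (FlashDecoding-style) scheme: each KV partition produces a partial output plus its online-softmax statistics (running max and sum of exponentials), and the reduce kernel merges those into the exact softmax result. A minimal NumPy sketch of that reduce math for a single query head (function names are illustrative, not the kernel's actual API):

```python
import numpy as np

def attention_partial(q, k, v):
    # One KV partition: return (partial_out, running_max, sumexp)
    # so partitions can be combined exactly via online softmax.
    s = k @ q                      # attention scores for one query, shape (kv_len,)
    m = s.max()
    p = np.exp(s - m)
    l = p.sum()
    return (p @ v) / l, m, l       # normalized partial output

def reduce_partitions(parts):
    # The reduce kernel's job: merge per-partition (out, m, l) triples.
    m_all = max(m for _, m, _ in parts)
    num = sum(np.exp(m - m_all) * l * out for out, m, l in parts)
    den = sum(np.exp(m - m_all) * l for _, m, l in parts)
    return num / den

rng = np.random.default_rng(0)
d, kv, part = 8, 512, 128
q = rng.standard_normal(d)
k = rng.standard_normal((kv, d))
v = rng.standard_normal((kv, d))

out = reduce_partitions(
    [attention_partial(q, k[i:i + part], v[i:i + part]) for i in range(0, kv, part)]
)

# Reference: single-pass softmax attention over the whole KV range.
s = k @ q
w = np.exp(s - s.max()) / np.exp(s - s.max()).sum()
assert np.allclose(out, w @ v)
```

Presumably the scratch buffers allocated on the Python side hold the per-partition outputs and softmax statistics until the reduce kernel runs; for the _ps512 kernel the partition length is fixed at 512.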

benchmark script: https://gist.github.com/Kingwl/ec70556729956a7e359b64a024820961

Partition Comparison

| case | mode | shape | after_ms | before_ms | delta_ms | speedup |
|---|---|---|---:|---:|---:|---:|
| decode-small | decode | B=1, q=1, kv=128 | 0.246 | 0.244 | 0.003 | -1.1% |
| decode-typical | decode | B=8, q=1, kv=2048 | 0.353 | 0.531 | -0.178 | 33.5% |
| decode-big-head | decode | B=8, q=1, kv=2048 | 0.576 | 0.690 | -0.114 | 16.5% |
| decode-long | decode | B=32, q=1, kv=8192 | 0.891 | 1.281 | -0.390 | 30.4% |
| decode-long-16384 | decode | B=32, q=1, kv=16384 | 1.771 | 3.418 | -1.647 | 48.2% |
| varlen-light | varlen | 1/128 4/256 16/512 64/1024 | 0.445 | 0.326 | 0.118 | -36.3% |
| varlen-typical | varlen | 32/512 64/1024 128/2048 256/4096 | 3.108 | 2.776 | 0.331 | -11.9% |
| varlen-single-long | varlen | 256/4096 | 2.429 | 2.136 | 0.292 | -13.7% |
| varlen-ragged-longtail | varlen | 1/4096 1/8192 8/512 128/2048 | 0.984 | 0.951 | 0.032 | -3.4% |
| varlen-ultra-long | varlen | 1/8192 1/16384 8/4096 128/16384 | 4.781 | 4.817 | -0.036 | 0.7% |

On the local benchmarks I ran, decode-long (kv=8192) improved by about 30%, and decode-long-16384 improved by about 48%.

num_threads and the partition size are not tuned yet.

Fixed #180

@Kingwl Kingwl force-pushed the feat/support-partition-for-attention branch from 75f3afb to abf0e92 Compare March 19, 2026 14:08
@WindChimeRan
Collaborator

WindChimeRan commented Mar 19, 2026

Thanks! @Kingwl

Please also pick a dataset you like (e.g., a long-context dataset) and run vllm bench serve to add end-to-end benchmark results. You can use git checkout main to compare performance before and after this PR.

See example here: #136 (I should probably add it to the doc somewhere ...)

@WindChimeRan
Collaborator

@Kingwl Please check the failed unit test. Is it just numerical non-determinism or algorithmic incorrectness?

BTW, could you please join the vLLM Slack? https://communityinviter.com/apps/vllm-dev/join-vllm-developers-slack

We are in the hw-metal channel.

This is a busy week for me, but I don't want to block your work. We can have a quick sync on Slack (text or call) if you have any questions or new issues you want to work on.

@Kingwl Kingwl force-pushed the feat/support-partition-for-attention branch from abf0e92 to 76b7ef7 Compare March 19, 2026 16:03
@Kingwl
Contributor Author

Kingwl commented Mar 19, 2026

Thanks for your comment.

That's fine, we can work asynchronously. I also need some time to figure out what I should do (I'm not familiar enough with the codebase yet).
I have reproduced the failing test case and will fix it soon.
I will also find a better dataset and benchmark.

And I'll join the Slack later.

@Kingwl Kingwl force-pushed the feat/support-partition-for-attention branch 2 times, most recently from c897f33 to 56e62e1 Compare March 20, 2026 03:15
@Kingwl
Contributor Author

Kingwl commented Mar 20, 2026

Parameters

device: M4 Max 64G
model: Qwen/Qwen3-0.6B
dataset: learnanything/sharegpt_v3_unfiltered_cleaned_split
serve: VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 vllm serve ... --max-model-len 2048
bench: vllm bench serve --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path /tmp/sharegpt-learnanything/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 100 --request-rate 10 --max-concurrency 32

Current Branch

============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  53.53
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              1.87
Output token throughput (tok/s):         412.15
Peak output token throughput (tok/s):    639.00
Peak concurrent requests:                36.00
Total token throughput (tok/s):          846.69
---------------Time to First Token----------------
Mean TTFT (ms):                          200.46
Median TTFT (ms):                        166.22
P99 TTFT (ms):                           412.85
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          60.79
Median TPOT (ms):                        57.33
P99 TPOT (ms):                           101.52
---------------Inter-token Latency----------------
Mean ITL (ms):                           54.75
Median ITL (ms):                         49.11
P99 ITL (ms):                            196.28
==================================================

Main branch

============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  50.16
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              1.99
Output token throughput (tok/s):         439.80
Peak output token throughput (tok/s):    671.00
Peak concurrent requests:                37.00
Total token throughput (tok/s):          903.51
---------------Time to First Token----------------
Mean TTFT (ms):                          174.67
Median TTFT (ms):                        149.87
P99 TTFT (ms):                           349.27
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          56.60
Median TPOT (ms):                        53.60
P99 TPOT (ms):                           99.02
---------------Inter-token Latency----------------
Mean ITL (ms):                           51.22
Median ITL (ms):                         45.26
P99 ITL (ms):                            190.36
==================================================

Diff

| metric | current vs main |
|---|---:|
| request throughput | -6.3% |
| output throughput | -6.3% |
| total throughput | -6.3% |
| mean TTFT | +14.8% slower |
| median TTFT | +10.9% slower |
| mean TPOT | +7.4% slower |
| median TPOT | +7.0% slower |
| mean ITL | +6.9% slower |
| median ITL | +8.5% slower |

On the ShareGPT serve benchmark with the same settings as PR #136, the current branch is consistently slower than main. Throughput is lower by about 6.3%, mean TTFT is 14.8% slower, and TPOT / ITL are also worse by roughly 7% to 8%. So this PR does not show an end-to-end serving win on this workload yet.

@WindChimeRan
Collaborator

@Kingwl Thanks for sharing.

Maybe it's just not the right dataset? Partitioning is for long context, but ShareGPT is short. Could you please try the sonnet dataset instead? See the #172 description for commands and how to download it.

@Kingwl Kingwl force-pushed the feat/support-partition-for-attention branch from 56e62e1 to ea07747 Compare March 20, 2026 12:46
@Kingwl
Contributor Author

Kingwl commented Mar 21, 2026

This PR is blocked by the prefix cache error. I will update it after the prefix cache issue has been fixed.

@WindChimeRan
Collaborator

> This PR is blocked by prefix cache error. I will update after prefix cache issue has been fixed.

#187 has been merged!

@Kingwl Kingwl force-pushed the feat/support-partition-for-attention branch from ea07747 to bd82f97 Compare March 21, 2026 17:53
@Kingwl
Contributor Author

Kingwl commented Mar 21, 2026

After rebasing onto main:

Benchmark Script Result

| case | mode | shape | current_ms | main_ms | delta_ms | speedup |
|---|---|---|---:|---:|---:|---:|
| decode-small | decode | B=1, q=1, kv=128 | 0.227 | 0.231 | -0.004 | 1.8% |
| decode-typical | decode | B=8, q=1, kv=2048 | 0.715 | 0.277 | 0.438 | -157.9% |
| decode-big-head | decode | B=8, q=1, kv=2048 | 1.030 | 0.560 | 0.469 | -83.7% |
| decode-long | decode | B=32, q=1, kv=8192 | 1.255 | 1.272 | -0.017 | 1.3% |
| decode-long-16384 | decode | B=32, q=1, kv=16384 | 1.832 | 3.372 | -1.539 | 45.7% |
| varlen-light | varlen | 1/128 4/256 16/512 64/1024 | 0.462 | 0.313 | 0.150 | -47.8% |
| varlen-typical | varlen | 32/512 64/1024 128/2048 256/4096 | 3.265 | 2.787 | 0.478 | -17.2% |
| varlen-single-long | varlen | 256/4096 | 2.516 | 2.194 | 0.322 | -14.7% |
| varlen-ragged-longtail | varlen | 1/4096 1/8192 8/512 128/2048 | 1.027 | 0.966 | 0.062 | -6.4% |
| varlen-ultra-long | varlen | 1/8192 1/16384 8/4096 128/16384 | 5.008 | 4.960 | 0.048 | -1.0% |

E2E Benchmark Result

I reran the end-to-end serving benchmark on both the current branch and origin/main with the same sonnet workload, using num-prompts=200.

VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 \
  vllm serve Qwen/Qwen3-0.6B --max-model-len 2048 --host 127.0.0.1 --port 8000

vllm bench serve \
  --backend vllm \
  --base-url http://127.0.0.1:8000 \
  --model Qwen/Qwen3-0.6B \
  --endpoint /v1/completions \
  --dataset-name sonnet \
  --dataset-path ./sonnet.txt \
  --sonnet-input-len 1024 \
  --sonnet-output-len 128 \
  --num-prompts 200 \
  --request-rate 10 \
  --max-concurrency 32

| metric | current branch | origin/main |
|---|---:|---:|
| Benchmark duration | 97.66 s | 89.52 s |
| Request throughput | 2.05 req/s | 2.23 req/s |
| Output throughput | 262.14 tok/s | 285.97 tok/s |
| Total token throughput | 2336.00 tok/s | 2548.40 tok/s |
| Mean TTFT | 1006.71 ms | 1072.37 ms |
| Median TTFT | 782.31 ms | 775.66 ms |
| P99 TTFT | 3439.54 ms | 3991.04 ms |
| Mean TPOT | 106.13 ms | 97.60 ms |
| Median TPOT | 105.55 ms | 101.28 ms |
| P99 TPOT | 122.47 ms | 105.85 ms |

Results:
I ran the benchmark many times; the results seem stable.
origin/main still has better steady-state throughput and better token generation speed on this workload.
The current branch slightly improves mean TTFT, but that does not translate into better end-to-end throughput.

@WindChimeRan
Collaborator

@Kingwl Thanks for the new benchmark results!

Your initial result shows improvement on ultra long context

On the local benchmarks I ran, decode-long (kv=8192) improved by about 30%, and decode-long-16384 improved by about 48%.

But the e2e test is only on 1024+128, so I'm a bit hesitant to draw any conclusions.

More background: the partition code was originally vendored from mistral.rs and later co-evolved with the varlen + online softmax kernel, but I never tested it. I'm also not sure how upstream vLLM kernels (e.g., unified triton, flash-attn, flashinfer) handle partitions.

But ultimately, we only have three logical options:

  1. Integrate the current partitioned Metal attention, turned on or off automatically or by env var (I guess usually automatically).
  2. Find the real bottleneck of the partition kernel, then fix or redesign it.
  3. Deprecate it completely and clean up the dead code.

I think the current benchmark results are a good foundation. Could you please provide more guidance on the decision making?

@Kingwl
Contributor Author

Kingwl commented Mar 21, 2026

I'm trying to run bigger inputs (2048+128, 4096+128). It seems obviously slower.
I will append more benchmarks if I get stable results.

@Kingwl Kingwl force-pushed the feat/support-partition-for-attention branch from bd82f97 to d1b1f76 Compare March 23, 2026 06:42
@Kingwl
Contributor Author

Kingwl commented Mar 23, 2026

Updated to use the partition kernel for long decode only.

Ran two cases for long decode:

For the 1024 + 4096 workload, the current branch is still slower than origin/main across throughput, TTFT, and TPOT, so this decode length is not yet enough to show a net E2E win from partitioning.

For the more decode-heavy 512 + 8192 workload, the current branch starts to pull ahead: overall duration, throughput, and TPOT improve by about 5% to 6%, although TTFT is still worse. This suggests the partition path helps steady-state long decoding, but not first-token latency.

Because of local hardware limits, it is difficult to run much larger-scale benchmarks reliably on this machine.
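The decode-only gating described above can be sketched as a simple heuristic. Everything here is illustrative (constant and function names are hypothetical, not the PR's actual code); the 4096 threshold matches the configuration labeled in the tables below:

```python
# Hypothetical dispatch heuristic mirroring "use partition kernel for long decode only".
PARTITION_THRESHOLD = 4096   # minimum KV length before partitioning pays off
PARTITION_SIZE = 512         # matches the _ps512 kernel's partition length

def use_partitioned_kernel(is_decode: bool, max_kv_len: int) -> bool:
    # Partitioning only helps when a decode step must scan a long KV cache;
    # prefill/varlen batches and short decodes stay on the no-partition path.
    return is_decode and max_kv_len >= PARTITION_THRESHOLD

def num_partitions(max_kv_len: int) -> int:
    # Ceiling division: how many partial results the reduce kernel must merge.
    return (max_kv_len + PARTITION_SIZE - 1) // PARTITION_SIZE

assert not use_partitioned_kernel(True, 2048)   # short decode: no-partition path
assert use_partitioned_kernel(True, 8192)       # long decode: partitioned path
assert num_partitions(8192) == 16
```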

1024 + 4096 (num-prompts=10, max-concurrency=4, request-rate=10, --ignore-eos)

| metric | current branch (decode-only, threshold>=4096) | origin/main | vs main |
|---|---:|---:|---|
| Benchmark duration (s) | 413.10 | 389.98 | slower 5.9% |
| Output token throughput (tok/s) | 99.15 | 105.03 | slower 5.6% |
| Total token throughput (tok/s) | 123.58 | 130.90 | slower 5.6% |
| Mean TTFT (ms) | 567.25 | 526.10 | slower 7.8% |
| Median TTFT (ms) | 640.46 | 564.25 | slower 13.5% |
| P99 TTFT (ms) | 748.08 | 715.48 | slower 4.6% |
| Mean TPOT (ms) | 33.70 | 31.81 | slower 5.9% |
| Median TPOT (ms) | 32.91 | 31.53 | slower 4.4% |
| P99 TPOT (ms) | 35.28 | 32.79 | slower 7.6% |

512 + 8192 (num-prompts=5, max-concurrency=2, request-rate=5, --ignore-eos)

| metric | current branch (decode-only, threshold>=4096) | origin/main | vs main |
|---|---:|---:|---|
| Benchmark duration (s) | 670.75 | 714.49 | faster 6.1% |
| Output token throughput (tok/s) | 61.07 | 57.33 | faster 6.5% |
| Total token throughput (tok/s) | 64.81 | 60.85 | faster 6.5% |
| Mean TTFT (ms) | 203.95 | 184.00 | slower 10.8% |
| Median TTFT (ms) | 206.21 | 157.07 | slower 31.3% |
| P99 TTFT (ms) | 354.77 | 321.73 | slower 10.3% |
| Mean TPOT (ms) | 28.65 | 30.42 | faster 5.8% |
| Median TPOT (ms) | 29.21 | 30.98 | faster 5.7% |
| P99 TPOT (ms) | 32.25 | 34.03 | faster 5.2% |

@ericcurtin
Collaborator

The test tolerance condition should use PARTITION_THRESHOLD, not PARTITION_SIZE; there is a missing sinks buffer binding (latent) and a duplicate constant definition.

@ericcurtin
Collaborator

conflicts

@Kingwl
Contributor Author

Kingwl commented Mar 24, 2026

Update soon

@Kingwl Kingwl force-pushed the feat/support-partition-for-attention branch from d1b1f76 to c1c939d Compare March 24, 2026 12:03
Signed-off-by: kingwl <kingwenlu@gmail.com>
@Kingwl Kingwl force-pushed the feat/support-partition-for-attention branch from c1c939d to 3d6d683 Compare March 24, 2026 12:50
@ericcurtin ericcurtin merged commit ff1ce9a into vllm-project:main Mar 25, 2026
5 checks passed
@Kingwl
Contributor Author

Kingwl commented Mar 25, 2026

Oh, thanks! But I'm still working on more benchmarks for partitioning.
I'll open a new issue/PR if I have something new.

ericcurtin pushed a commit that referenced this pull request Apr 6, 2026
…225)

<img width="2493" height="927" alt="bench_primitive_comparison"
src="https://github.com/user-attachments/assets/54a5f038-f16c-4bbe-a96d-343dcfae04fa"
/>




Builds on the spike in #209. Related: #188.

### What

Replace the eager `reshape_and_cache` + `paged_attention_v2_online`
dispatch with:

1. **MLX-native scatter** for cache writes - pure functional,
graph-tracked, donation-eligible. Replaces the custom
`reshape_and_cache` Metal kernel in the production path.
2. **`PagedAttentionPrimitive`** for attention - a read-only primitive
that dispatches the paged attention Metal kernel lazily.

Both operations are fully lazy. No per-layer `mx.eval` or
`mx.synchronize`. The entire 28-layer model builds one lazy graph,
evaluated once by the model runner.
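As a shape-level illustration of the scatter cache write, here is a NumPy stand-in (the real path uses MLX arrays and stays inside the lazy graph; shapes, sizes, and names are illustrative):

```python
import numpy as np

# Illustrative paged KV-cache layout: flat slots of (num_heads, head_dim).
num_blocks, block_size, num_heads, head_dim = 4, 16, 2, 8
kv_cache = np.zeros((num_blocks * block_size, num_heads, head_dim))

# slot_mapping[i] gives the flat cache slot for token i of this batch step.
slot_mapping = np.array([5, 17, 33])
new_kv = np.ones((3, num_heads, head_dim))

# One scatter (fancy-indexed assignment) replaces the custom
# reshape_and_cache Metal kernel in the eager path.
kv_cache[slot_mapping] = new_kv

assert kv_cache[5].sum() == num_heads * head_dim   # written slot
assert kv_cache[6].sum() == 0                      # untouched slot
```

In MLX the analogous update is expressed functionally, so it participates in graph tracking and buffer donation instead of mutating in place.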

**Bug fix**: custom primitives must not call `add_temporary` inside
`eval_gpu`. MLX's `add_temporary` removes buffer pointers from the
command encoder's fence tracking, breaking cross-command-buffer
synchronization when the graph is evaluated lazily. The fix:
`from_primitive=true` skips all `add_temporary` calls. MLX's evaluator
already manages array lifetimes via the completion handler. This matches
the pattern in MLX's official `axpby` extension example.

### Why

The original eager path does **1 `mx.eval` + 1 `mx.synchronize`** per
layer, each a CPU-GPU sync point. With 28 layers, that is 56 sync points
per decode step.

The primitive path eliminates all of them. The scatter participates in
MLX's lazy graph, and the attention primitive dispatches correctly
across command buffer boundaries thanks to the `add_temporary` fix.


### Test

6/6 deterministic golden tests (bit-exact). 362/362 broader suite.

### Future Work

- Wire partitioned attention (#181) into the primitive path
- Clean up dead eager code paths (`reshape_and_cache`,
`metal_unified_attention`, etc.)


### Benchmark (sonnet 1024+128, 100 prompts, concurrency 8, 5 warmups)

| Metric | main | this PR | Change |
|---|---:|---:|---:|
| Duration (s) | 288.71 | 257.01 | **-11.0%** |
| Output tok/s | 44.34 | 49.80 | **+12.3%** |
| Total tok/s | 394.96 | 443.68 | **+12.3%** |
| Mean TTFT (ms) | 4810.26 | 4771.41 | -0.8% |
| Mean TPOT (ms) | 139.98 | 121.47 | **-13.2%** |
| P99 TPOT (ms) | 169.16 | 149.75 | **-11.5%** |

<details>
<summary>Full benchmark output</summary>

**main (paged attention, eager dispatch)**
```
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             8
Request rate configured (RPS):           10.00
Benchmark duration (s):                  288.71
Total input tokens:                      101230
Total generated tokens:                  12800
Request throughput (req/s):              0.35
Output token throughput (tok/s):         44.34
Peak output token throughput (tok/s):    112.00
Peak concurrent requests:                11.00
Total token throughput (tok/s):          394.96
---------------Time to First Token----------------
Mean TTFT (ms):                          4810.26
Median TTFT (ms):                        4865.42
P99 TTFT (ms):                           10145.45
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          139.98
Median TPOT (ms):                        139.36
P99 TPOT (ms):                           169.16
---------------Inter-token Latency----------------
Mean ITL (ms):                           139.98
Median ITL (ms):                         79.76
P99 ITL (ms):                            2661.48
==================================================
```

**this PR (MLX scatter + paged attention primitive, fully lazy)**
```
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             8
Request rate configured (RPS):           10.00
Benchmark duration (s):                  257.01
Total input tokens:                      101230
Total generated tokens:                  12800
Request throughput (req/s):              0.39
Output token throughput (tok/s):         49.80
Peak output token throughput (tok/s):    120.00
Peak concurrent requests:                11.00
Total token throughput (tok/s):          443.68
---------------Time to First Token----------------
Mean TTFT (ms):                          4771.41
Median TTFT (ms):                        5165.51
P99 TTFT (ms):                           9396.19
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          121.47
Median TPOT (ms):                        118.78
P99 TPOT (ms):                           149.75
---------------Inter-token Latency----------------
Mean ITL (ms):                           121.47
Median ITL (ms):                         70.34
P99 ITL (ms):                            2660.84
==================================================
```

</details>

<details>
<summary>Benchmark config</summary>

- Model: Qwen/Qwen3-0.6B
- Dataset: sonnet (1024 input + 128 output)
- Prompts: 100, rate: 10, concurrency: 8, warmups: 5
- Memory fraction: 0.3 (paged path)
- Hardware: Apple M1 Pro, 32 GB RAM

</details>

---------

Signed-off-by: ran <hzz5361@psu.edu>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Alex-ai-future pushed a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026
(same commit message as above)

Development

Successfully merging this pull request may close these issues.

Check if kernel partition is working or not

3 participants