[Continuous Batching] Unified prefilling & decoding prototype#172

Merged
ericcurtin merged 13 commits into vllm-project:main from WindChimeRan:continuous_batching_input
Mar 20, 2026

Conversation

Collaborator

@WindChimeRan WindChimeRan commented Mar 18, 2026

(Attached benchmark chart: bench_cb_comparison)

Summary

This PR does not touch the varlen kernel itself; it just wires it into unified prefilling & decoding.

  1. Unified chunked prefill and decode into a single forward pass. All request types (fresh prefill, continuation chunk, decode) go through one model() call per scheduler step via prepare_unified() + the varlen paged attention kernel.
  2. Upgraded the model runner from v0-style (separate prefill/decode phases) to v1-style (single collection loop → one unified forward → per-request post-processing).
  3. Fixed chunked prefill from O(n²) to O(n) — continuation chunks now process only the delta tokens with a start_pos offset, instead of recomputing from position 0.
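The batch layout implied by points 1 and 3 can be sketched roughly as follows. prepare_unified() is the PR's real entry point, but this body, the dict fields, and the cumulative-length naming are illustrative assumptions, not the actual implementation:

```python
# Sketch (not the PR's code) of assembling one unified batch: decodes
# contribute one token each, prefill continuation chunks contribute only
# their delta tokens, and cumulative segment boundaries (the usual varlen
# attention convention) mark where each request's tokens start and end.

def prepare_unified_sketch(requests):
    """requests: list of dicts with 'tokens' (this step's delta tokens)
    and 'start_pos' (tokens already in the KV cache)."""
    flat_tokens, positions, cu_seqlens = [], [], [0]
    for req in requests:
        delta = req["tokens"]       # 1 token for decode, a chunk for prefill
        start = req["start_pos"]    # offset into the sequence -> O(n) chunking
        flat_tokens.extend(delta)
        positions.extend(range(start, start + len(delta)))
        cu_seqlens.append(cu_seqlens[-1] + len(delta))
    return flat_tokens, positions, cu_seqlens
```

Because each continuation chunk carries only its delta with a start_pos offset, the per-chunk cost no longer grows with how much of the prompt was already processed.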

Post-processing has four arms after the unified forward pass:

  • Decode: sample from position [0..num_decode-1], append token to existing state, increment generated_tokens
  • New complete prefill: sample from last position of the segment, create a new RequestState with generated_tokens=1
  • Intermediate chunk (new or cached): sample is discarded — the forward pass only populates the KV cache for this chunk
  • Cached last chunk: sample from last position, append token to existing RequestState, transition request from prefill to decode phase
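The four arms above could be dispatched roughly like this. RequestState and generated_tokens are named in the PR, but the field set and the "kind" labels here are stand-ins for illustration:

```python
# Hedged sketch of the four post-processing arms; not the PR's actual types.
from dataclasses import dataclass, field

@dataclass
class RequestState:
    token_ids: list = field(default_factory=list)
    generated_tokens: int = 0
    phase: str = "prefill"

def postprocess(req, sampled_token, states):
    """req: dict with 'id' and 'kind'; sampled_token: the token sampled at
    this request's last segment position (unused for intermediate chunks)."""
    kind = req["kind"]
    if kind == "decode":
        st = states[req["id"]]
        st.token_ids.append(sampled_token)
        st.generated_tokens += 1
    elif kind == "new_full_prefill":
        states[req["id"]] = RequestState(
            token_ids=[sampled_token], generated_tokens=1, phase="decode")
    elif kind == "intermediate_chunk":
        pass  # sample discarded; the pass only populated the KV cache
    elif kind == "cached_last_chunk":
        st = states[req["id"]]
        st.token_ids.append(sampled_token)
        st.generated_tokens += 1
        st.phase = "decode"  # transition from prefill to decode
```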

Deprecation of the vLLM v0-style model runner:

| Deleted | File | Why |
| --- | --- | --- |
| `_batched_decode_paged()` | model_runner.py | Separate decode forward pass — replaced by unified |
| `_prefill_packed_paged()` | model_runner.py | Separate prefill forward pass — replaced by unified |
| `_run_packed_prefill()` | model_runner.py | Batch-splitting wrapper around `_prefill_packed_paged` |
| `prepare_decode()` | paged_attention_common.py | Built decode-only context — replaced by `prepare_unified()` |
| `prepare_prefill_packed()` | paged_attention_common.py | Built prefill-only context — replaced by `prepare_unified()` |
| `_metal_kernel_decode_attention()` | paged_attention.py | Decode-specific attention dispatch — the unified path uses `_metal_kernel_prefill_attention()` for everything |

All six were v0-style "phase-separated" functions. They're replaced by two v1-style unified functions: prepare_unified() and _unified_prefill_decode_paged().
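Why a single attention entry point can serve both phases: under causal varlen attention, the keys a query may attend to depend only on that query's absolute position, so a decode step is just a segment with one query token at the end of its context. A toy mask builder illustrating the rule (illustrative only, not the Metal kernel):

```python
# Causal rule: a query at absolute position p may attend to keys 0..p.
# The same rule covers a decode step (one query at the end of the context)
# and a chunked-prefill segment (queries at start_pos..start_pos+k-1).

def causal_mask_for_segment(start_pos, num_query_tokens, total_kv_len):
    """True where query i (absolute position start_pos + i) may see key j."""
    return [[j <= start_pos + i for j in range(total_kv_len)]
            for i in range(num_query_tokens)]
```

A decode over a 4-token context is `causal_mask_for_segment(4, 1, 5)` (one row, all keys visible); a 2-token continuation chunk starting at position 2 is `causal_mask_for_segment(2, 2, 4)` (a staircase). One kernel, two shapes of the same mask.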

Benchmark

I use the sonnet dataset this time, with 1024 input tokens and 128 output tokens. This makes the O(n²) problem of the "before this PR" code stand out.

wget https://raw.githubusercontent.com/vllm-project/vllm/main/benchmarks/sonnet.txt

# Paged path server:
VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 \
    vllm serve Qwen/Qwen3-0.6B --max-model-len 2048

# mlx_lm path server:
vllm serve Qwen/Qwen3-0.6B --max-model-len 2048



# Client:
vllm bench serve --backend vllm --model Qwen/Qwen3-0.6B \
    --endpoint /v1/completions \
    --dataset-name sonnet \
    --dataset-path ../vllm/benchmarks/sonnet.txt \
    --sonnet-input-len 1024 \
    --sonnet-output-len 128 \
    --num-prompts 100 \
    --request-rate 10 \
    --max-concurrency 32

Results

Full benchmark output

This PR (Paged, Continuous Batching, vllm v1):

============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  185.35
Total input tokens:                      101230
Total generated tokens:                  11828
Request throughput (req/s):              0.54
Output token throughput (tok/s):         63.81
Peak output token throughput (tok/s):    231.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          609.96
---------------Time to First Token----------------
Mean TTFT (ms):                          11537.56
Median TTFT (ms):                        7359.41
P99 TTFT (ms):                           38604.55
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          406.60
Median TPOT (ms):                        409.58
P99 TPOT (ms):                           762.81
---------------Inter-token Latency----------------
Mean ITL (ms):                           395.03
Median ITL (ms):                         143.64
P99 ITL (ms):                            3664.41
==================================================

Before this PR (Paged, vllm v0):

============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  323.08
Total input tokens:                      101230
Total generated tokens:                  12800
Request throughput (req/s):              0.31
Output token throughput (tok/s):         39.62
Peak output token throughput (tok/s):    192.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          352.95
---------------Time to First Token----------------
Mean TTFT (ms):                          17453.35
Median TTFT (ms):                        11976.57
P99 TTFT (ms):                           59035.62
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          652.37
Median TPOT (ms):                        600.09
P99 TPOT (ms):                           1033.64
---------------Inter-token Latency----------------
Mean ITL (ms):                           652.37
Median ITL (ms):                         182.30
P99 ITL (ms):                            5758.13
==================================================

mlx_lm path (non-paged KV):

============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  346.86
Total input tokens:                      101230
Total generated tokens:                  12800
Request throughput (req/s):              0.29
Output token throughput (tok/s):         36.90
Peak output token throughput (tok/s):    49.00
Peak concurrent requests:                33.00
Total token throughput (tok/s):          328.75
---------------Time to First Token----------------
Mean TTFT (ms):                          87405.18
Median TTFT (ms):                        104122.81
P99 TTFT (ms):                           105163.68
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          50.72
Median TPOT (ms):                        51.05
P99 TPOT (ms):                           54.25
---------------Inter-token Latency----------------
Mean ITL (ms):                           50.72
Median ITL (ms):                         47.37
P99 ITL (ms):                            82.28
==================================================

Signed-off-by: ran <hzz5361@psu.edu>
@WindChimeRan WindChimeRan force-pushed the continuous_batching_input branch from 6d6ddb2 to 528cb3a Compare March 19, 2026 21:34
@WindChimeRan WindChimeRan force-pushed the continuous_batching_input branch from 528cb3a to 7cd4cb6 Compare March 19, 2026 21:38
Resolve conflict in model_runner.py: keep STTExecutor class (from vllm-project#173),
drop stale MAX_PACKED_PREFILL_TOKENS constant (removed in this branch).

@WindChimeRan WindChimeRan marked this pull request as ready for review March 20, 2026 08:23
@ericcurtin ericcurtin merged commit 41187bc into vllm-project:main Mar 20, 2026
5 checks passed