
[Core] Concurrent partial prefills for V1#31330

Open
ppppqp wants to merge 16 commits intovllm-project:mainfrom
ppppqp:panqp--partial-prefill

Conversation

@ppppqp
Contributor

@ppppqp ppppqp commented Dec 25, 2025

Purpose

Implements #14003, since we decided to include it in v1.
Referencing the implementation of the original PR: #10235

In short, the concurrent partial prefill technique prevents large requests from starving small requests, improving throughput and TTFT overall. There are three key parameters for this strategy:

  • max_num_partial_prefills: controls how many requests are guaranteed to make progress in each scheduler run. At the start of each run, the token budget is distributed evenly across the partial prefill slots, ensuring that at least max_num_partial_prefills requests get new tokens prefilled. If a request does not use its full share, the leftover budget can be used to serve additional requests.
  • max_long_partial_prefills: controls how many large requests can occupy the partial prefill slots. For example, with max_num_partial_prefills=4 and max_long_partial_prefills=2, each run can have 2 large requests and 2 small requests. The default is 1.
  • long_prefill_token_threshold: controls the criterion for what counts as a "large request". If the number of prompt tokens exceeds the threshold, the request is considered large and is limited by max_long_partial_prefills.
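As a rough illustration of the budget distribution described above (a hypothetical sketch, not the actual vLLM scheduler code; the function name is made up):

```python
def split_prefill_budget(token_budget: int, max_num_partial_prefills: int) -> int:
    """Evenly divide the scheduler's per-run token budget across the
    partial prefill slots, so every request granted a slot is guaranteed
    to make progress this run. Hypothetical helper for illustration."""
    return token_budget // max_num_partial_prefills

# e.g. a 2048-token budget split across 4 slots gives 512 tokens per slot
print(split_prefill_budget(2048, 4))  # 512
```

With max_num_partial_prefills=1 this degenerates to today's behavior: one request gets the whole budget.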

Test Plan

Unit tests with parity to the original PR.
I'm not sure whether we should also maintain parity with this unit test; based on my testing, the alignment logic seems to be abstracted away from the scheduler. Would need some help here.
https://github.com/vllm-project/vllm/pull/10235/files#diff-2c6af6e25b8d1074f25ef5ad2901121b30bc1528de74d2b3625636fcb8181624R782-R831

Test Result

Benchmark plan

I followed the setup of the original PR with a custom dataset I generated from the ShareGPT creative writing dataset. The token-count distribution is shown below, with three groups of small/medium/large prompts:
[image: prompt token-count distribution]

I tested three versions on an A40, using the dataset benchmark-final.jsonl.zip:

  1. main branch
  2. this branch with max_num_partial_prefills=1
  3. this branch with max_num_partial_prefills=4 & long_prefill_token_threshold=2048

For each version, I also tested with output_num=128 (the default) and output_num=1.
vllm serve NousResearch/Hermes-3-Llama-3.1-8B [--max_num_partial_prefills=4 --long_prefill_token_threshold=2048]


vllm bench serve \
  --model NousResearch/Hermes-3-Llama-3.1-8B \
  --dataset-name custom \
  --dataset-path benchmark.jsonl \
  --num-prompts -1 \
  --metric-percentiles 80,85,90,95,99 \
  --request-rate 12 \
  --disable_shuffle \
  [--output_len=1]

Sorry that the chart is probably not organized in the clearest way. Please compare the stats in the greyed columns with each other, and those in the white columns with each other.
[image: benchmark results chart]

Some interesting observations:

  1. The TTFT improvement for the output_len=128 group is not significant (~20%). After some investigation, I believe this is because the decoding phase largely averages out the TTFT: even after the small requests are prefilled, they still wait in the queue for a fairly long time, capped by max_num_seqs, which is 128 by default. If we consider only the prefill phase (i.e. set output_len to 1), the TTFT improvement is significant (~400%). I ran an extra experiment to further confirm this (last column of the chart, with max_num_seqs=1024).
  2. As max_num_partial_prefills increases (1 -> 4 -> 16), throughput increases fairly consistently.
  3. There is a tradeoff on TTFT P99: it gets consistently worse as throughput increases.

Considerations

Some best-practice suggestions around this feature:

  1. long_prefill_token_threshold works best when set close to max_num_batched_tokens. If it is much higher, a single large request can still starve the queue. If it is lower, throughput for large requests degrades quickly: if a large request is the only request in the queue, it does not get the full budget of the run.
  2. max_num_seqs must be tuned up in accordance with the throughput improvement.

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which executes a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the v1 label Dec 25, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements concurrent partial prefills for V1, which is a significant feature for improving throughput and latency. The changes introduce new configuration options and complex scheduling logic. The implementation looks mostly solid, but I've found a critical typo in the configuration validation that would lead to a runtime error.

Comment on lines +119 to +128
class PrefillState:
    """Lightweight state used to reason about a request's prefill status."""

    # whether the request is in the prefill phase
    is_prefill: bool
    # number of remaining tokens to prefill
    remaining_tokens: int
    # whether the prefill is considered a long prefill
    is_long_prefill: bool

Contributor Author


Adding a new dataclass here as an abstraction, in case we want to implement a more sophisticated concurrency strategy in the future (e.g. the strategy for determining what counts as a long prefill).

@mergify

mergify bot commented Dec 25, 2025

Hi @ppppqp, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

            request, request.num_computed_tokens
        )
        if prefill_state.is_prefill:
            num_new_tokens = min(num_new_tokens, partial_prefill_slot_budget)
Contributor Author

@ppppqp ppppqp Dec 25, 2025


Cap the computed num_new_tokens with the allocated budget. This ensures that all requests that are granted a partial prefill slot can progress.

Comment on lines +1105 to +1117
def _is_prefill_with_tokens(request: Request, num_computed_tokens: int) -> bool:
    """Check if the request is in the prefill phase"""
    return (
        request.num_output_tokens == 0
        and num_computed_tokens < request.num_prompt_tokens
    )

@staticmethod
def _remaining_prefill_tokens_with_tokens(
    request: Request, num_computed_tokens: int
) -> int:
    """Get the number of remaining prefill tokens"""
    return max(request.num_prompt_tokens - num_computed_tokens, 0)
Contributor Author


Not sure about these two helper functions. Could you take a look and see if my understanding is correct?

@ppppqp ppppqp force-pushed the panqp--partial-prefill branch from a8b17dc to bb38145 Compare December 25, 2025 07:05
@mergify

mergify bot commented Dec 25, 2025

(Same pre-commit failure notice as above.)

@ppppqp ppppqp force-pushed the panqp--partial-prefill branch from dc6c074 to e5204f9 Compare December 25, 2025 07:29
@ppppqp ppppqp marked this pull request as draft December 25, 2025 07:30
@chaunceyjiang
Collaborator

Hi @ppppqp, could you provide some benchmarks, for example under 16k / 32k / 64k ISL scenarios?

@ppppqp ppppqp force-pushed the panqp--partial-prefill branch from 310c3c4 to 57f2517 Compare December 27, 2025 00:24
@robertgshaw2-redhat
Collaborator

I believe we already support this functionality in V1. Specifically, the --long_prefill_token_threshold.

@ppppqp
Contributor Author

ppppqp commented Dec 27, 2025

I believe we already support this functionality in V1. Specifically, the --long_prefill_token_threshold.

Hi! Are you talking about this PR?
If so, I believe it only added the parameter but does not fully resolve the throughput issue. Consider this case:
The budget for each run is 2048. If we set the threshold greater than 2048, the threshold does nothing. If we set the threshold lower than 2048, say 512, and requests arrive in this order:

Request 1 (100k tokens)
Request 2 (100k tokens)
Request 3 (100k tokens)
Request 4 (100k tokens)
Request 5 (5 tokens)

then Requests 1-4 will still consume the entire budget each run (512 * 4), and Request 5 will be starved.
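As a toy illustration of that scenario (a hypothetical first-come-first-served simulation, not vLLM code):

```python
def schedule_run(queue, token_budget, threshold):
    """One FCFS scheduler pass: each request is capped at `threshold`
    tokens but otherwise served in arrival order until the per-run
    token budget is exhausted. Hypothetical sketch for illustration."""
    scheduled = {}
    for name, remaining in queue:
        if token_budget == 0:
            break
        grant = min(remaining, threshold, token_budget)
        scheduled[name] = grant
        token_budget -= grant
    return scheduled

queue = [("req1", 100_000), ("req2", 100_000),
         ("req3", 100_000), ("req4", 100_000), ("req5", 5)]
# With budget 2048 and threshold 512, the four long requests take
# 512 tokens each and req5 never gets scheduled.
print(schedule_run(queue, 2048, 512))
```

Running this shows req5 absent from the scheduled set on every pass until the long requests finish, which is exactly the starvation the slot limit in this PR is meant to prevent.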

@ppppqp
Contributor Author

ppppqp commented Dec 27, 2025

Yes agreed - so the updated semantics would then limit the number of long prefills running at once?

Yes, this PR actually enables max_num_partial_prefills and max_long_partial_prefills, which guarantee that at least max_num_partial_prefills - max_long_partial_prefills short requests are protected from starvation.
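Put differently, the invariant can be sketched as (hypothetical helper, not vLLM code):

```python
def reserved_short_slots(max_num_partial_prefills: int,
                         max_long_partial_prefills: int) -> int:
    """Slots that long prefills can never occupy, so at least this many
    short requests can make progress in each scheduler run."""
    return max_num_partial_prefills - max_long_partial_prefills

# e.g. 4 slots with at most 2 long prefills leaves 2 slots for short requests
print(reserved_short_slots(4, 2))  # 2
```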

@ppppqp ppppqp changed the title [WIP][Core] Concurrent partial prefills for V1 [Core] Concurrent partial prefills for V1 Dec 27, 2025
@mergify

mergify bot commented Dec 27, 2025

(Same pre-commit failure notice as above.)

1 similar comment

@ppppqp
Contributor Author

ppppqp commented Dec 27, 2025

Hi @ppppqp, could you provide some benchmarks, for example under 16k / 32k / 64k ISL scenarios?

@chaunceyjiang
Hi! Do you mean tuning input_len to these values, or should I actually generate a dataset with input prompts of those lengths? I'm not sure whether the benchmark I've done is sufficient or whether extra effort is needed. Thanks!

@ppppqp ppppqp force-pushed the panqp--partial-prefill branch from fee590d to e3e3e0d Compare December 28, 2025 00:22
@chaunceyjiang
Collaborator

Hi! Do you mean tuning input_len to these values, or should I actually generate a dataset with input prompts of those lengths? I'm not sure whether the benchmark I've done is sufficient or whether extra effort is needed. Thanks!

i.e.

vllm bench serve --model XXXX --random-input-len 2048 --random-output-len 1024  --max-concurrency 120 --num-prompts 480 --port  8990

@ppppqp
Contributor Author

ppppqp commented Jan 1, 2026

Hi! Do you mean tuning input_len to these values, or should I actually generate a dataset with input prompts of those lengths? I'm not sure whether the benchmark I've done is sufficient or whether extra effort is needed. Thanks!

i.e.

vllm bench serve --model XXXX --random-input-len 2048 --random-output-len 1024  --max-concurrency 120 --num-prompts 480 --port  8990

@chaunceyjiang
Actually, I don't think this makes sense. I tested it, and if I specify --random-input-len 2048, all inputs have length 2047, so both cases should perform the same (with a uniform distribution, partial prefilling should not observe any long request).
Let me know if there's anything else I need to benchmark 🙏.

If you intentionally want this uniform-distribution benchmark, I can also run it and provide the stats!

ppppqp added 14 commits January 6, 2026 19:36
Signed-off-by: Qiping Pan <panqiping@outlook.com>
@ppppqp ppppqp force-pushed the panqp--partial-prefill branch from e3e3e0d to 7925874 Compare January 7, 2026 03:37
ppppqp added 2 commits January 6, 2026 19:59
Signed-off-by: Qiping Pan <panqiping@outlook.com>
@ppppqp
Contributor Author

ppppqp commented Jan 7, 2026

@chaunceyjiang
Hi! This is the benchmark result using

vllm serve NousResearch/Hermes-3-Llama-3.1-8B [--max_num_partial_prefills=4]

vllm bench serve --model NousResearch/Hermes-3-Llama-3.1-8B --random-input-len 2048 --random-output-len 1024  --max-concurrency 120 --num-prompts 480
  • Comparing the main branch and this branch with max_num_partial_prefills=1, the change in this PR causes no TTFT regression.
  • Comparing the main branch and this branch with max_num_partial_prefills=4, the change in this PR causes about a 1% TTFT slowdown for uniformly distributed input.
[image: benchmark results for uniform input lengths]

Please let me know if I need to test with --random-input-len 16384/32768/65536. The benchmark would take significantly longer to run, but I'm happy to do that as well if needed!

@mergify

mergify bot commented Jan 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ppppqp.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 12, 2026
@noooop
Collaborator

noooop commented Jan 24, 2026

I personally really like Concurrent partial prefills and am looking forward to this PR landing.

