
[Core] Reduce TTFT with concurrent partial prefills#10235

Merged
comaniac merged 71 commits into vllm-project:main from opendatahub-io:prefill-slots
Feb 14, 2025

Conversation

@joerunde
Collaborator

@joerunde joerunde commented Nov 11, 2024

Replaces #10061, as inspired by @njhill and @comaniac's comments. Co-authored by @prashantgupta24

Context: our customers running large multi-tenanted SaaS deployments of vLLM have a problem where high volumes of small-prompt requests are usually processed smoothly, but quickly pile up in a giant queue when a small number of large-prompt requests are submitted. We see the decoding throughput drop to zero on multiple replicas when this happens.

The current chunked prefill implementation only allows a single sequence to be partially prefilled at a time. This has a few limitations:

  • Multiple medium-sized prompts must wait to be prefilled serially, increasing TTFT for those in the back of the queue
  • A single very large prompt will block all other prompts from prefilling for many iterations. This can eventually starve decoding: for example, a 130k token prompt with --max-num-batched-tokens=512 will take about 250 iterations to prefill, in which time the currently decoding sequences may all finish. Send a few of these requests at once and very quickly nothing will be decoding.
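The iteration count in that example is just the chunking arithmetic; a quick sanity check in Python (illustrative, not vLLM code):

```python
import math

# With chunked prefill, a prompt is consumed in chunks of at most
# max_num_batched_tokens tokens per scheduler iteration.
def prefill_iterations(prompt_tokens: int, max_num_batched_tokens: int) -> int:
    return math.ceil(prompt_tokens / max_num_batched_tokens)

# The 130k-token prompt from the example above:
print(prefill_iterations(130_000, 512))  # 254, i.e. "about 250 iterations"
```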

This PR implements both

  • An explicit setting for the number of sequences that can be partially prefilled concurrently. This can be configured with --max-num-partial-prefills=N
  • A limit on the number of “very long prompt” sequences that can be prefilled concurrently. This can be configured with
    • --max-long-partial-prefills=N to set the limit on the number of long sequences that can be concurrently prefilled. This defaults to 1 sequence.
    • --long-prefill-threshold=x% to set the percentage of the context length that determines which sequences are considered "long". This defaults to 4%.
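To illustrate what the threshold flag controls, here is a minimal sketch of the "long" classification it describes; the names and structure are ours, not the actual vLLM internals:

```python
def is_long_prefill(prompt_tokens: int, max_model_len: int,
                    long_prefill_threshold: float = 0.04) -> bool:
    """A request is "long" if its prompt exceeds the configured
    fraction of the model's context length (default 4%)."""
    return prompt_tokens > long_prefill_threshold * max_model_len

# With a 128k context, the default cutoff is 5120 tokens:
print(is_long_prefill(5_000, 128_000))   # False
print(is_long_prefill(30_000, 128_000))  # True
```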

This is implemented in the v0 scheduler. We’re aware that the v1 implementation is underway and will later become the default, but we need a fix for our customers soon and we hope that what we discover here may help inform a different, better solution in the v1 scheduler.

To test this we created three scenarios: a “medium request” case, a “large request” case, and a “mixed” case.

For the medium request case, we created a subset of the sharegpt dataset with 900 small requests (<50 prompt characters) and 100 of the largest requests (typically between 10k and 20k prompt characters, which we call “medium” sized). We modified the benchmark_serving.py test to not filter out any of the small or large requests, and ran it with this dataset. What we expect to find is similar throughput compared to the main branch, but much lower TTFT on the small requests. Since 10% of the requests are larger than the rest, we should see better TTFT at p90 and below, with comparable TTFT above p90.

For the large request case, we took 990 of the smallest requests from the sharegpt dataset, and then took 10 of the largest requests and duplicated the prompts until they were around 100k characters in length. We ran this in the same way as the medium request case, and here we expect to see smaller TTFT across the board since the small requests will no longer be blocked from prefilling by the few very large requests.

For the mixed case, we used 850 “small”, and 140 “medium” requests, as well as 10 "large" requests where we duplicated the prompts up to 200k characters.
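The dataset splits above amount to a filter on prompt length; this is an illustrative sketch with hypothetical helper names, not the actual modifications made to benchmark_serving.py:

```python
# Partition ShareGPT-style records by prompt length (character counts,
# matching the thresholds described above).
def split_by_prompt_length(records, small_max=50, medium_min=10_000):
    small = [r for r in records if len(r["prompt"]) < small_max]
    medium = [r for r in records if len(r["prompt"]) >= medium_min]
    return small, medium

records = [{"prompt": "short question"}, {"prompt": "x" * 15_000}]
small, medium = split_by_prompt_length(records)
print(len(small), len(medium))  # 1 1
```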

All tests were run on a single 80GB A100, with the command:

python benchmarks/benchmark_serving.py --model meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-path ${test_case} --metric-percentiles 80,85,90,95,99 --request-rate 12

We ran the tests against the main branch (commit 874f551b3626321f6bf9a902b8fd9fc1fa7c7f2e), as well as this PR with the new optimization both disabled (--max-num-partial-prefills=1) and enabled (--max-num-partial-prefills=4).

The results are shown here:
[results chart]

The TTFT improvements are easy to see: in the medium case we cut the p90 TTFT in half, and in the large case we cut it by nearly 30x. In both cases we did not measure a throughput drop when running with --max-num-partial-prefills=1, and the throughput drop with --max-num-partial-prefills=4 is minimal.

Surprisingly, along with the massive TTFT improvements in the "mixed" test case, we also see a 4% throughput improvement (3506 tokens/s up from 3368 tokens/s). Based on the fact that ITL still looks a little slower, it seems that the throughput is higher simply because more requests were able to be successfully scheduled at the same time.

cc @rickyyx


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well-documented to ensure future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies user-facing behaviors of vLLM. It helps vLLM users understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.library.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and may not review the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which starts a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

vllm/config.py Outdated
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
num_prefill_slots: int = 1,
Collaborator

Is this actually "maximum number of prefill sequence in a batch"? If so could we name it something more informative, like max_num_batched_prefill_seqs ?

Collaborator Author

It's technically only the number of partial prefills allowed in a batch. You could still have, say, 100 sequence groups with 5 prompt tokens each all scheduled in a single step here.

max_num_partial_prefills?

Comment on lines +406 to +407
# Requests with more than (4% max context length) tokens to prefill
# are "big".
Collaborator

Why this definition and threshold?

Collaborator Author

The entire goal here is to not allow decode to be starved by the prefill phase blocking on long requests; from this part of the PR description:

A single very large prompt will block all other prompts from prefilling for many iterations. This can eventually starve decoding: for example, a 130k token prompt with --max-num-batched-tokens=512 will take about 250 iterations to prefill, in which time the currently decoding sequences may all finish. Send a few of these requests at once and very quickly nothing will be decoding.

Just allowing concurrent partial prefills doesn't solve the problem by itself, because multiple long requests could still block up the prefill. So what we do is allow only a single long request to prefill at a time, and pull smaller requests from the waiting queue instead of more long ones.
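That policy can be sketched as a simple admission check; this is our own simplification of the idea, not the PR's actual scheduler code:

```python
def can_schedule_prefill(is_long: bool,
                         partial_prefills: int,
                         long_partial_prefills: int,
                         max_num_partial_prefills: int = 4,
                         max_long_partial_prefills: int = 1) -> bool:
    # Overall cap on concurrent partial prefills.
    if partial_prefills >= max_num_partial_prefills:
        return False
    # Separate, stricter cap on "long" prompts so a burst of huge
    # requests cannot monopolize the prefill budget.
    if is_long and long_partial_prefills >= max_long_partial_prefills:
        return False
    return True

# One long prefill already running: a second long request must wait,
# but a short request can still be pulled from the waiting queue.
print(can_schedule_prefill(True, 1, 1))   # False
print(can_schedule_prefill(False, 1, 1))  # True
```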

@mergify mergify bot added the frontend label Nov 13, 2024
@joerunde joerunde marked this pull request as ready for review November 14, 2024 20:36

@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
def test_chunked_prefill_with_actual_engine(model: str,
Collaborator Author

cc @rickyyx here's what we tried to do to test that the sampler doesn't throw any assertions: we put multiple prompts into an engine and manually step it forward with them all partially prefilled.

@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 14, 2024
@mergify

mergify bot commented Nov 20, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @joerunde.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 20, 2024
joerunde and others added 12 commits November 20, 2024 10:02
@eladamittai

eladamittai commented Mar 2, 2025

Wow, this feature is very cool! Awesome PR! But does it help in case of 2 large requests? Let's say I have --max-num-batched-tokens=256 and 130 running decode requests. I received 2 long prefill requests, let's say 30k input tokens each. Will it divide the remaining budget between the 2 requests to execute them in parallel or will it just take the first request and chunk it as usual while the other one waits in the queue? And if it does divide it between the 2 requests, does it improve the TTFT in this situation?
And, let's say in the meantime I receive a small request: will it take priority and be prefilled in full before the first 2 long requests?

@hibukipanim

@joerunde there seem to be some typos regarding --long-prefill-token-threshold.

  • the param description says "Defaults to 4%% of the model's context length.", but the default value is 0:
    long_prefill_token_threshold: Optional[int] = 0
  • in the ArgParser the type is float but in the EngineArgs dataclass it's int

And by the way, it's not clear to me whether it's 4% from the bottom or from the top. E.g. if the model has a context length of 100, is "long" above 4 or above 96? 4% sounds too small to be a threshold for "long", no?

@schoennenbeck
Contributor

schoennenbeck commented Mar 10, 2025

@hibukipanim Both 0 and None don't make sense as a threshold. The logic in the parser takes int(.04 * model_max_len) as the threshold if long_prefill_token_threshold is not set to a positive integer by the user. So the default might be 0 but that is not the value that is then used as the threshold.

It is indeed 4% from the bottom. With modern models this actually makes sense since they often have a max context length of ~100k and 4000 tokens is indeed already quite long (though for most use cases it probably still makes sense to tune this a little and not rely on the default).
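The fallback described above can be sketched as follows; this is illustrative, matching the behavior as described rather than the exact vLLM parser code:

```python
# A non-positive long_prefill_token_threshold falls back to 4% of the
# model's context length, so the documented default of 0 is never used
# as the actual threshold.
def effective_long_threshold(long_prefill_token_threshold: int,
                             model_max_len: int) -> int:
    if long_prefill_token_threshold > 0:
        return long_prefill_token_threshold
    return int(0.04 * model_max_len)

print(effective_long_threshold(0, 100_000))     # 4000
print(effective_long_threshold(8192, 100_000))  # 8192
```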

lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
@madhavaggar

madhavaggar commented Apr 9, 2025

  1. With V1, I cannot set max_num_partial_prefills>1. Is this expected behavior, or is this a pending feature?
  2. What is the default behavior while enabling_chunked_prefill=True with vLLM V1?
  3. What settings can I use to optimize this for a large system prompt?
  4. And what is the best way to determine the best settings for these parameters? Is it by using the torch profiler and observing the time for each prefill chunk?

shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
@hibukipanim hibukipanim mentioned this pull request May 23, 2025
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Sep 24, 2025
# What this PR does / why we need it?

When processing a mix of large and small requests, the TTFT of responses
is significantly reduced. Please refer to
vllm-project/vllm#10235, which achieves the same effect by simply
limiting the number of concurrent prefills for long requests. This
solution can be applied to both AscendScheduler (V0) and the vLLM
Scheduler (V1). Tests show that TTFT can be significantly improved when
handling such mixed requests. However, this capability is currently
missing when Ascend Scheduler is enabled.

This benchmark used the Qwen3-8B model, with a context length of 128K,
running on a single card.

Regarding dataset selection, the sharegpt_clean dataset is used, with
its content concatenated and cropped. Small requests with token=50 and
medium requests with token=10240 were constructed (there were also large
requests with token=102400, but these were ignored because when using
the Prefill First scheduling strategy, max_num_batched_tokens will not
be set to such a large value). When loading vLLM, set
max_num_batched_tokens=22000. This length can accommodate two
medium-sized requests and some short requests, reflecting an extreme
scenario where the budget is almost entirely occupied by longer
requests.

Next, we mix 990 small requests and 100 medium requests into one type of
load scenario (hereinafter referred to as 10%), and similarly generate
load scenarios with 5% medium requests and 1% load scenarios.

Performance tests were conducted separately for enabling vLLMScheduler,
AscendScheduler, and AscendScheduler (long prompt concurrency set to 1).

- vLLM version: v0.10.2
- vLLM main:
vllm-project/vllm@1dfea5f

---------

Signed-off-by: Csrayz <jover@cmbchina.com>
@hibukipanim hibukipanim mentioned this pull request Oct 8, 2025
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025
luolun pushed a commit to luolun/vllm-ascend that referenced this pull request Nov 19, 2025
luolun pushed a commit to luolun/vllm-ascend that referenced this pull request Nov 19, 2025
hwhaokun pushed a commit to hwhaokun/vllm-ascend that referenced this pull request Nov 19, 2025
NSDie pushed a commit to NSDie/vllm-ascend that referenced this pull request Nov 24, 2025
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 9, 2025
@samuelslinux

Not bad


Labels

frontend ready ONLY add when PR is ready to merge/full CI is needed
