
[Core] Whisper enable FULL_DECODE_ONLY CudaGraph #30072

Merged
vllm-bot merged 8 commits into vllm-project:main from NickLucche:whisper-gc
Dec 10, 2025

Conversation

@NickLucche
Collaborator

@NickLucche NickLucche commented Dec 4, 2025

Overview

Whisper currently only supports the PIECEWISE cudagraph mode, but it does not support torch.compile, so in practice it ends up running in eager mode (see the profiler capture below; a cudaGraphLaunch query turns up nothing).

[profiler screenshot: Pasted image 20251202144814]

This PR addresses another important performance limitation by adding CUDA graph support through the FULL_DECODE_ONLY mode.
This guarantees that all decode steps after the first one simply replay the captured graph.
Note that the first decode step branches in control flow because encoder_outputs are present to populate the cross-attention KV cache; hence we only graph-capture the steps that follow.
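The dispatch rule described above can be sketched as follows. This is a minimal stand-alone illustration: the CUDAGraphMode enum mirrors vLLM's, but should_replay_graph and its arguments are hypothetical names, not the actual model-runner API.

```python
from enum import Enum


class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL_DECODE_ONLY = 2
    FULL_AND_PIECEWISE = 3


def should_replay_graph(
    mode: CUDAGraphMode,
    is_uniform_decode: bool,
    has_encoder_outputs: bool,
) -> bool:
    """Replay the captured full graph only on uniform decode steps that do
    NOT carry encoder outputs, i.e. every decode step after the first one.
    The first decode step must run eagerly because it populates the
    cross-attention KV cache from the encoder outputs."""
    if mode is not CUDAGraphMode.FULL_DECODE_ONLY:
        return False
    return is_uniform_decode and not has_encoder_outputs


# First decode step: encoder outputs present -> run eagerly.
print(should_replay_graph(CUDAGraphMode.FULL_DECODE_ONLY, True, True))   # False
# Subsequent decode steps: replay the captured graph.
print(should_replay_graph(CUDAGraphMode.FULL_DECODE_ONLY, True, False))  # True
```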

New profiler capture:
[profiler screenshot: Pasted image 20251204180234]

Full explicit command:

vllm serve openai/whisper-large-v3-turbo --max-num-batched-tokens 32874 -cc.cudagraph_mode=FULL_DECODE_ONLY

New default:

vllm serve openai/whisper-large-v3-turbo
...
INFO 12-04 17:44:48 [vllm.py:704] Encoder-decoder models do not support FULL_AND_PIECEWISE. Overriding cudagraph_mode to FULL_DECODE_ONLY.

Benchmark

Pre

================================================================================
RESULTS SUMMARY
================================================================================
Total samples: 50
Successful: 50
Failed: 0
Total time: 1.36s
Average latency: 0.88s
Throughput: 36.82 requests/s

Post

================================================================================
RESULTS SUMMARY
================================================================================
Total samples: 50
Successful: 50
Failed: 0
Total time: 1.02s
Average latency: 0.61s
Throughput: 49.11 requests/s
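For reference, the two summaries above work out to roughly a 33% throughput gain and a 31% latency reduction:

```python
# Numbers taken directly from the RESULTS SUMMARY blocks above.
pre_throughput, post_throughput = 36.82, 49.11  # requests/s
pre_latency, post_latency = 0.88, 0.61          # seconds

throughput_gain = (post_throughput / pre_throughput - 1) * 100
latency_reduction = (1 - post_latency / pre_latency) * 100

print(f"throughput gain:   {throughput_gain:.1f}%")   # 33.4%
print(f"latency reduction: {latency_reduction:.1f}%")  # 30.7%
```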

cc @ProExpertProg @LucasWilkinson @robertgshaw2-redhat

Comment thread vllm/config/vllm.py Outdated
Comment on lines +649 to +650
self.compilation_config.cudagraph_mode
in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL_AND_PIECEWISE)
Collaborator Author

@ProExpertProg need your eyes on this bit

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a significant performance enhancement for Whisper models by enabling CUDA graph support through a new FULL_DECODE_ONLY mode. The changes are well-implemented and logically sound. Key modifications include updating the configuration to handle the new CUDA graph mode for encoder-decoder models and adjusting the model runner to correctly distinguish between the initial encoder-involved step and subsequent decode-only steps. This ensures that CUDA graphs are only used for the decode steps, which is the correct approach for models like Whisper. The benchmarks provided demonstrate a substantial improvement in throughput and latency. Overall, this is an excellent contribution that improves the performance of encoder-decoder models in vLLM.

Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
Comment thread vllm/config/vllm.py Outdated
Comment thread vllm/v1/worker/gpu_model_runner.py
@mergify
Contributor

mergify Bot commented Dec 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@NickLucche
Collaborator Author

Thanks for reviewing @ProExpertProg @LucasWilkinson, I hope I've addressed your comments.

Collaborator

@ProExpertProg ProExpertProg left a comment

Looks good apart from 2 nits!

Comment thread vllm/v1/cudagraph_dispatcher.py Outdated
@@ -145,7 +145,7 @@ def dispatch(
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
use_cascade_attn: bool = False,
piecewise_or_eager_only: bool = False,
Collaborator

skip_attention or attention_cg_unsupported?

Collaborator Author

this is as per @LucasWilkinson's suggestion

Collaborator

@LucasWilkinson LucasWilkinson Dec 9, 2025

Actually went with disable_full in #30173 to help with line-width overflows; we can resolve the conflicts depending on which lands first (I think we should go with disable_full since it's a bit more terse).

Comment thread vllm/config/vllm.py
Comment on lines 681 to +689
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
elif (
self.model_config.is_encoder_decoder
and self.compilation_config.cudagraph_mode
not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
):
Collaborator

if somebody sets mode=piecewise, we don't handle it here, but we should:

            # check support based on model type
            if self.model_config is not None:
                if self.model_config.pooler_config is not None and self.compilation_config.cudagraph_mode.has_full_cudagraphs():
                    ...
                elif (
                    self.model_config.is_encoder_decoder
                    and self.compilation_config.cudagraph_mode
                    not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
                ):

Collaborator Author

Right! Done:

vllm serve openai/whisper-large-v3-turbo -cc.cudagraph_mode=PIECEWISE
...
INFO 12-09 11:15:02 [vllm.py:690] Encoder-decoder models do not support PIECEWISE. Overriding cudagraph_mode to FULL_DECODE_ONLY.
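The override behaviour shown in the log lines above can be sketched as a simplified stand-alone model. resolve_cudagraph_mode is a hypothetical helper for illustration, not the actual VllmConfig code; only the NONE/FULL_DECODE_ONLY whitelist mirrors the checked-in condition.

```python
from enum import Enum


class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL_DECODE_ONLY = 2
    FULL_AND_PIECEWISE = 3


def resolve_cudagraph_mode(
    is_encoder_decoder: bool, requested: CUDAGraphMode
) -> CUDAGraphMode:
    """Encoder-decoder models support only NONE and FULL_DECODE_ONLY;
    any other requested mode (PIECEWISE, FULL_AND_PIECEWISE) is
    downgraded to FULL_DECODE_ONLY with a warning."""
    if is_encoder_decoder and requested not in (
        CUDAGraphMode.NONE,
        CUDAGraphMode.FULL_DECODE_ONLY,
    ):
        print(
            f"Encoder-decoder models do not support {requested.name}. "
            "Overriding cudagraph_mode to FULL_DECODE_ONLY."
        )
        return CUDAGraphMode.FULL_DECODE_ONLY
    return requested


# PIECEWISE is now handled too, matching the log output above.
print(resolve_cudagraph_mode(True, CUDAGraphMode.PIECEWISE).name)  # FULL_DECODE_ONLY
```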

@github-project-automation github-project-automation Bot moved this to In review in NVIDIA Dec 5, 2025
Collaborator

@LucasWilkinson LucasWilkinson left a comment

LGTM; thanks for doing this!


@LucasWilkinson LucasWilkinson added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 9, 2025
@NickLucche NickLucche enabled auto-merge (squash) December 9, 2025 15:26
@mergify
Contributor

mergify Bot commented Dec 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the multi-modality Related to multi-modality (#4194) label Dec 10, 2025
@@ -102,6 +102,7 @@ def run_test(
max_model_len=448,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
Member

Could you explain a bit why the test needs to set this explicitly?

Collaborator Author

@DarkLight1337 There's a subtle difference in output which makes the test fail.


[2025-12-09T18:42:49Z]   -  And the 0-1 pitch on the way to Edgar Martinez. Swung on the line down the left field line for a base hit. Here comes Joy. Here is Junior to third base. They're going to wave him in. The throw to the plate will be late. The Mariners are going to play for the American League Championship. I don't believe it. It just continues. My, oh, my.
--
[2025-12-09T18:42:49Z]   ?                                                                                                                                                                                                                                                                                                                                                    ^
[2025-12-09T18:42:49Z]   +  And the 0-1 pitch on the way to Edgar Martinez. Swung on the line down the left field line for a base hit. Here comes Joy. Here is Junior to third base. They're going to wave him in. The throw to the plate will be late. The Mariners are going to play for the American League Championship. I don't believe it. It just continues. My, oh, my God.

but tbh I think checking a copy-pasted output token by token is too strict.
We can restructure the test to use checks that are invariant to such small differences, or just relax it in a separate PR.

Member

Yeah we can check logprobs instead. For now can you add a code comment with a TODO to fix this later?

Collaborator Author

sure
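A logprob-based check along the lines suggested above might look roughly like this. It is a sketch only: logprobs_close, the pair-list format, and the tolerance are hypothetical, not vLLM's actual test utilities.

```python
import math


def logprobs_close(
    ref: list[tuple[str, float]],
    test: list[tuple[str, float]],
    atol: float = 0.3,
) -> bool:
    """Compare per-token logprobs instead of exact output strings.
    Two outputs are considered equivalent when their tokens match up to
    the shorter sequence length and each pair of logprobs is within an
    absolute tolerance, so a trailing 'my.' vs 'my God.' divergence
    driven by near-tied logits no longer fails the test."""
    n = min(len(ref), len(test))
    return all(
        r_tok == t_tok and math.isclose(r_lp, t_lp, abs_tol=atol)
        for (r_tok, r_lp), (t_tok, t_lp) in zip(ref[:n], test[:n])
    )


ref = [("My", -0.10), (",", -0.20), ("oh", -0.05)]
out = [("My", -0.12), (",", -0.25), ("oh", -0.04)]
print(logprobs_close(ref, out))  # True
```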

Signed-off-by: NickLucche <nlucches@redhat.com>
@vllm-bot vllm-bot merged commit c756fb6 into vllm-project:main Dec 10, 2025
49 of 51 checks passed
@github-project-automation github-project-automation Bot moved this from In review to Done in NVIDIA Dec 10, 2025
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>

Labels

multi-modality Related to multi-modality (#4194) nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

5 participants