[Core] Whisper enable FULL_DECODE_ONLY CudaGraph #30072
vllm-bot merged 8 commits into vllm-project:main
Conversation
```python
self.compilation_config.cudagraph_mode
in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL_AND_PIECEWISE)
```
@ProExpertProg need your eyes on this bit
Code Review
This pull request introduces a significant performance enhancement for Whisper models by enabling CUDA graph support through a new FULL_DECODE_ONLY mode. The changes are well-implemented and logically sound. Key modifications include updating the configuration to handle the new CUDA graph mode for encoder-decoder models and adjusting the model runner to correctly distinguish between the initial encoder-involved step and subsequent decode-only steps. This ensures that CUDA graphs are only used for the decode steps, which is the correct approach for models like Whisper. The benchmarks provided demonstrate a substantial improvement in throughput and latency. Overall, this is an excellent contribution that improves the performance of encoder-decoder models in vLLM.
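For context, the gating the review describes — full CUDA graphs only for the decode steps — can be sketched as follows. The mode names appear in the PR; the dispatcher function itself is a simplified illustration, not vLLM's actual code:

```python
from enum import Enum


class CUDAGraphMode(Enum):
    # Mode names taken from the PR; this enum is a simplified stand-in.
    NONE = "none"
    PIECEWISE = "piecewise"
    FULL_DECODE_ONLY = "full_decode_only"
    FULL_AND_PIECEWISE = "full_and_piecewise"


def use_full_cudagraph(mode: CUDAGraphMode, uniform_decode: bool) -> bool:
    """Illustrative only: replay a full CUDA graph solely for uniform decode
    batches, and only when the configured mode permits full graphs for them."""
    if mode in (CUDAGraphMode.FULL_DECODE_ONLY, CUDAGraphMode.FULL_AND_PIECEWISE):
        return uniform_decode
    return False
```

Under this sketch, `FULL_DECODE_ONLY` never captures the mixed prefill/encoder step, which matches the review's point that graphs should cover only the decode steps.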
This pull request has merge conflicts that must be resolved before it can be merged.
Thanks for reviewing @ProExpertProg @LucasWilkinson, I hope I've addressed your comments.
ProExpertProg left a comment
Looks good apart from 2 nits!
```diff
@@ -145,7 +145,7 @@ def dispatch(
     num_tokens: int,
     uniform_decode: bool,
     has_lora: bool,
-    use_cascade_attn: bool = False,
+    piecewise_or_eager_only: bool = False,
```
skip_attention or attention_cg_unsupported?
This is as per @LucasWilkinson's suggestion.
Actually, I went with disable_full in #30173 to help with line-width overflows; we can resolve the conflicts depending on which lands first. (I think we should go with disable_full since it's a bit more terse.)
```diff
                 "Overriding cudagraph_mode to PIECEWISE."
             )
             self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-        elif self.model_config.is_encoder_decoder:
-            logger.warning_once(
-                "Encoder-decoder models do not support full cudagraphs. "
-                "Overriding cudagraph_mode to PIECEWISE."
+        elif (
+            self.model_config.is_encoder_decoder
+            and self.compilation_config.cudagraph_mode
+            not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
+        ):
```
if somebody sets mode=piecewise, we don't handle it here, but we should:
```python
# check support based on model type
if self.model_config is not None:
    if (
        self.model_config.pooler_config is not None
        and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
    ):
        ...
    elif (
        self.model_config.is_encoder_decoder
        and self.compilation_config.cudagraph_mode
        not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
    ):
```
Right! Done:
```
vllm serve openai/whisper-large-v3-turbo -cc.cudagraph_mode=PIECEWISE
...
INFO 12-09 11:15:02 [vllm.py:690] Encoder-decoder models do not support PIECEWISE. Overriding cudagraph_mode to FULL_DECODE_ONLY.
```
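A minimal, hypothetical sketch of this override behavior (function and constant names are made up; only the warning text mirrors the log above, and modes are plain strings for brevity):

```python
# Hypothetical sketch, not vLLM's implementation: encoder-decoder models accept
# only NONE or FULL_DECODE_ONLY; any other requested mode is overridden.
SUPPORTED_ENCDEC_MODES = {"NONE", "FULL_DECODE_ONLY"}


def resolve_encdec_cudagraph_mode(requested: str) -> str:
    if requested not in SUPPORTED_ENCDEC_MODES:
        print(
            f"Encoder-decoder models do not support {requested}. "
            "Overriding cudagraph_mode to FULL_DECODE_ONLY."
        )
        return "FULL_DECODE_ONLY"
    return requested
```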
LucasWilkinson left a comment
LGTM; thanks for doing this!
```diff
@@ -102,6 +102,7 @@ def run_test(
         max_model_len=448,
         tensor_parallel_size=tensor_parallel_size,
         distributed_executor_backend=distributed_executor_backend,
+        enforce_eager=True,
```
Could you explain a bit why the test needs to set this explicitly?
@DarkLight1337 There's a subtle difference in output which makes the test fail.
```diff
- And the 0-1 pitch on the way to Edgar Martinez. Swung on the line down the left field line for a base hit. Here comes Joy. Here is Junior to third base. They're going to wave him in. The throw to the plate will be late. The Mariners are going to play for the American League Championship. I don't believe it. It just continues. My, oh, my.
+ And the 0-1 pitch on the way to Edgar Martinez. Swung on the line down the left field line for a base hit. Here comes Joy. Here is Junior to third base. They're going to wave him in. The throw to the plate will be late. The Mariners are going to play for the American League Championship. I don't believe it. It just continues. My, oh, my God.
```
But tbh I think checking a copy-pasted output token by token is too strict.
We can restructure the test to use some new invariant-based techniques, or just relax it in a separate PR.
Yeah we can check logprobs instead. For now can you add a code comment with a TODO to fix this later?
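Relaxing the check to compare logprobs could look like this sketch (the helper name and tolerance are invented for illustration; this is not the test's actual code):

```python
import math


def logprobs_close(ref_logprobs, test_logprobs, atol=0.25):
    """Token-wise logprob comparison with a tolerance, instead of exact text
    matching. Both inputs are lists of per-token logprobs; atol is illustrative.
    """
    if len(ref_logprobs) != len(test_logprobs):
        return False
    return all(
        math.isclose(a, b, abs_tol=atol)
        for a, b in zip(ref_logprobs, test_logprobs)
    )
```

A check like this would tolerate the "my." vs "my God." divergence above as long as the per-token probabilities stay close, which is usually the invariant we actually care about.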
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
Overview
Whisper currently only supports the PIECEWISE cudagraph mode, but it does NOT support torch.compile, so in practice it ends up running in eager mode (see the profiler screenshot below, filtered with a `cudagraph launch` query).
This PR addresses another important performance limitation by adding CUDA graph support through the FULL_DECODE_ONLY mode. This guarantees that all decode steps after the first one will just replay the captured graph.
Mind that the first decode step actually branches in control flow, due to `encoder_outputs` being present to populate the cross-attention KV cache. Hence we only run the steps that follow under CUDA graphs.
New profiler screenshot:

Full explicit command:
New default:
Benchmark
Pre
Post
cc @ProExpertProg @LucasWilkinson @robertgshaw2-redhat