From 03539507511f0deb313fe6aadbc9ec775c3c816e Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 3 Dec 2025 18:47:52 +0000 Subject: [PATCH 1/8] init Signed-off-by: NickLucche --- vllm/config/vllm.py | 33 +++++++++++++++--------------- vllm/v1/worker/gpu_model_runner.py | 4 ++++ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 8f27db001330..e849f72d0ad8 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -666,17 +666,17 @@ def has_blocked_weights(): default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level] self._apply_optimization_level_defaults(default_config) - if ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and self.compilation_config.mode != CompilationMode.VLLM_COMPILE - ): - logger.info( - "Cudagraph mode %s is not compatible with compilation mode %s." - "Overriding to NONE.", - self.compilation_config.cudagraph_mode, - self.compilation_config.mode, - ) - self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + # if ( + # self.compilation_config.cudagraph_mode != CUDAGraphMode.PIECEWISE OR FULL_AND_PIECEWISE + # and self.compilation_config.mode != CompilationMode.VLLM_COMPILE + # ): + # logger.info( + # "Cudagraph mode %s is not compatible with compilation mode %s." + # "Overriding to NONE.", + # self.compilation_config.cudagraph_mode, + # self.compilation_config.mode, + # ) + # self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE # async tp is built on top of sequence parallelism # and requires it to be enabled. @@ -703,11 +703,12 @@ def has_blocked_weights(): ) self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE elif self.model_config.is_encoder_decoder: - logger.warning_once( - "Encoder-decoder models do not support full cudagraphs. " - "Overriding cudagraph_mode to PIECEWISE." - ) - self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + pass + # logger.warning_once( + # "Encoder-decoder models do not support full cudagraphs. " + # "Overriding cudagraph_mode to PIECEWISE." + # ) + # self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE # disable cudagraph when enforce eager execution if self.model_config is not None and self.model_config.enforce_eager: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 39456d2e80ed..df8499eafdf6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3073,6 +3073,7 @@ def execute_model( record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): + # if same_step(prefill and decode_step 0) -> eager o/w call torch.compiled model for subsequent ones model_output = self._model_forward( input_ids=input_ids, positions=positions, @@ -4096,6 +4097,7 @@ def _dummy_run( assert num_tokens_padded <= self.max_num_tokens model_kwargs = self._init_model_kwargs(num_tokens_padded) if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: + # NOT ENC-DEC input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] model_kwargs = { @@ -4153,6 +4155,7 @@ def _dummy_run( ubatch_slices=ubatch_slices_padded, ), ): + # CALLING MODEL outputs = self.model( input_ids=input_ids, positions=positions, @@ -4452,6 +4455,7 @@ def profile_run(self) -> None: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers + # TORCH COMPILE FOR DECODE (step 1 onward) hidden_states, last_hidden_states = self._dummy_run( self.max_num_tokens, is_profile=True ) From 97eb2c3a8a61b936902123601b84cf7526a68731 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 4 Dec 2025 15:40:59 +0000 Subject: [PATCH 2/8] mind the padding! Signed-off-by: NickLucche --- vllm/v1/worker/gpu_model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index df8499eafdf6..1f44e3855866 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1267,6 +1267,8 @@ def _get_encoder_seq_lens( if not isinstance(kv_cache_spec, CrossAttentionSpec): return None, None + # Zero out buffer for padding requests that are not actually scheduled (CGs) + self.encoder_seq_lens.np[:num_reqs] = 0 # Build encoder_seq_lens array mapping request indices to # encoder lengths for inputs scheduled in this batch for req_id in num_scheduled_tokens: From 548af89800ad622c66b65837dc45521bdd8efb69 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 4 Dec 2025 17:11:09 +0000 Subject: [PATCH 3/8] has_encoder_output check+config changes Signed-off-by: NickLucche --- vllm/config/vllm.py | 43 +++++++++++++++++------------- vllm/v1/worker/gpu_model_runner.py | 10 +++++-- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index e849f72d0ad8..206f5e5c831b 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -666,17 +666,18 @@ def has_blocked_weights(): default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level] self._apply_optimization_level_defaults(default_config) - # if ( - # self.compilation_config.cudagraph_mode != CUDAGraphMode.PIECEWISE OR FULL_AND_PIECEWISE - # and self.compilation_config.mode != CompilationMode.VLLM_COMPILE - # ): - # logger.info( - # "Cudagraph mode %s is not compatible with compilation mode %s." - # "Overriding to NONE.", - # self.compilation_config.cudagraph_mode, - # self.compilation_config.mode, - # ) - # self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + if ( + self.compilation_config.cudagraph_mode + in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL_AND_PIECEWISE) + and self.compilation_config.mode != CompilationMode.VLLM_COMPILE + ): + logger.info( + "Cudagraph mode %s is not compatible with compilation mode %s." + "Overriding to NONE.", + self.compilation_config.cudagraph_mode, + self.compilation_config.mode, + ) + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE # async tp is built on top of sequence parallelism # and requires it to be enabled. @@ -702,13 +703,19 @@ def has_blocked_weights(): "Overriding cudagraph_mode to PIECEWISE." ) self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - elif self.model_config.is_encoder_decoder: - pass - # logger.warning_once( - # "Encoder-decoder models do not support full cudagraphs. " - # "Overriding cudagraph_mode to PIECEWISE." - # ) - # self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + elif ( + self.model_config.is_encoder_decoder + and self.compilation_config.cudagraph_mode + not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY) + ): + logger.info_once( + "Encoder-decoder models do not support %s. " + "Overriding cudagraph_mode to FULL_DECODE_ONLY.", + self.compilation_config.cudagraph_mode.name, + ) + self.compilation_config.cudagraph_mode = ( + CUDAGraphMode.FULL_DECODE_ONLY + ) # disable cudagraph when enforce eager execution if self.model_config is not None and self.model_config.enforce_eager: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1f44e3855866..ba50547914a0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2766,6 +2766,7 @@ def _determine_batch_execution_and_padding( # be improved in model runner v2) force_uniform_decode: bool | None = None, force_has_lora: bool | None = None, + num_encoder_reqs: int = 0, ) -> tuple[ CUDAGraphMode, BatchDescriptor, @@ -2782,6 +2783,12 @@ def _determine_batch_execution_and_padding( if force_uniform_decode is None else force_uniform_decode ) + # Encoder-decoder models only support GC for decoder_step > 0 (no enc_output + # is present). Also, chunked-prefill is disabled, so batch are uniform. + has_encoder_output = ( + self.model_config.is_encoder_decoder and num_encoder_reqs > 0 + ) + uniform_decode = uniform_decode and not has_encoder_output has_lora = ( len(self.input_batch.lora_id_to_lora_request) > 0 @@ -2999,6 +3006,7 @@ def execute_model( num_scheduled_tokens_np=num_scheduled_tokens_np, max_num_scheduled_tokens=max_num_scheduled_tokens, use_cascade_attn=cascade_attn_prefix_lens is not None, + num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), ) logger.debug( @@ -3075,7 +3083,6 @@ def execute_model( record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): - # if same_step(prefill and decode_step 0) -> eager o/w call torch.compiled model for subsequent ones model_output = self._model_forward( input_ids=input_ids, positions=positions, @@ -4157,7 +4164,6 @@ def _dummy_run( ubatch_slices=ubatch_slices_padded, ), ): - # CALLING MODEL outputs = self.model( input_ids=input_ids, positions=positions, From 21626df977c91f3fa81739a79801f357eceec849 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 4 Dec 2025 17:44:25 +0000 Subject: [PATCH 4/8] cruft Signed-off-by: NickLucche --- vllm/v1/worker/gpu_model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ba50547914a0..594a1c6be8bd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4463,7 +4463,6 @@ def profile_run(self) -> None: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers - # TORCH COMPILE FOR DECODE (step 1 onward) hidden_states, last_hidden_states = self._dummy_run( self.max_num_tokens, is_profile=True ) From c06589f8b154194c0d92902d371abd846d2b300d Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 5 Dec 2025 17:09:11 +0000 Subject: [PATCH 5/8] address review Signed-off-by: NickLucche --- vllm/config/vllm.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 206f5e5c831b..5d4956bc4123 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -666,9 +666,9 @@ def has_blocked_weights(): default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level] self._apply_optimization_level_defaults(default_config) + if ( - self.compilation_config.cudagraph_mode - in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL_AND_PIECEWISE) + self.compilation_config.cudagraph_mode.requires_piecewise_compilation() and self.compilation_config.mode != CompilationMode.VLLM_COMPILE ): logger.info( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 594a1c6be8bd..ca06f048f290 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2783,12 +2783,11 @@ def _determine_batch_execution_and_padding( if force_uniform_decode is None else force_uniform_decode ) - # Encoder-decoder models only support GC for decoder_step > 0 (no enc_output + # Encoder-decoder models only support CG for decoder_step > 0 (no enc_output # is present). Also, chunked-prefill is disabled, so batch are uniform. has_encoder_output = ( self.model_config.is_encoder_decoder and num_encoder_reqs > 0 ) - uniform_decode = uniform_decode and not has_encoder_output has_lora = ( len(self.input_batch.lora_id_to_lora_request) > 0 @@ -2808,7 +2807,7 @@ def _determine_batch_execution_and_padding( ) cudagraph_mode, batch_descriptor = dispatch_cudagraph( - num_tokens_padded, use_cascade_attn + num_tokens_padded, use_cascade_attn or has_encoder_output ) num_tokens_padded = batch_descriptor.num_tokens @@ -4106,7 +4105,6 @@ def _dummy_run( assert num_tokens_padded <= self.max_num_tokens model_kwargs = self._init_model_kwargs(num_tokens_padded) if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: - # NOT ENC-DEC input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] model_kwargs = { From cd13d1f998c875bfafd1cfa280a8b355df6fb352 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 9 Dec 2025 11:17:17 +0000 Subject: [PATCH 6/8] address piecewise mode Signed-off-by: NickLucche --- vllm/config/vllm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 5d4956bc4123..607bb44cddd2 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -693,18 +693,18 @@ def has_blocked_weights(): if current_platform.support_static_graph_mode(): # if cudagraph_mode has full cudagraphs, we need to check support - if ( - self.compilation_config.cudagraph_mode.has_full_cudagraphs() - and self.model_config is not None - ): - if self.model_config.pooler_config is not None: + if model_config := self.model_config: + if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and model_config.pooler_config is not None + ): logger.warning_once( "Pooling models do not support full cudagraphs. " "Overriding cudagraph_mode to PIECEWISE." ) self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE elif ( - self.model_config.is_encoder_decoder + model_config.is_encoder_decoder and self.compilation_config.cudagraph_mode not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY) ): From 427ad52894feeb6cdc2c96a18f43b73baf891876 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 10 Dec 2025 08:45:40 +0000 Subject: [PATCH 7/8] skip cg in tests Signed-off-by: NickLucche --- tests/models/multimodal/generation/test_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index eca2b61e37d5..c95445507976 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -102,6 +102,7 @@ def run_test( max_model_len=448, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, ) as vllm_model: llm = vllm_model.llm From fc2beddcf9b45b93bf6f28e5715ed8dd37ff9ca5 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 10 Dec 2025 10:53:28 +0000 Subject: [PATCH 8/8] todo Signed-off-by: NickLucche --- tests/models/multimodal/generation/test_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index c95445507976..f634a998e8b2 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -102,6 +102,7 @@ def run_test( max_model_len=448, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, + # TODO (NickLucche) figure out output differences with non-eager and re-enable enforce_eager=True, ) as vllm_model: llm = vllm_model.llm