Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/models/multimodal/generation/test_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def run_test(
max_model_len=448,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
# TODO (NickLucche) figure out output differences with non-eager and re-enable
enforce_eager=True,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain a bit why the test needs to set this explicitly?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DarkLight1337 There's a subtle difference in output which makes the test fail.


[2025-12-09T18:42:49Z]   -  And the 0-1 pitch on the way to Edgar Martinez. Swung on the line down the left field line for a base hit. Here comes Joy. Here is Junior to third base. They're going to wave him in. The throw to the plate will be late. The Mariners are going to play for the American League Championship. I don't believe it. It just continues. My, oh, my.
--
[2025-12-09T18:42:49Z]   ?                                                                                                                                                                                                                                                                                                                                                    ^
[2025-12-09T18:42:49Z]   +  And the 0-1 pitch on the way to Edgar Martinez. Swung on the line down the left field line for a base hit. Here comes Joy. Here is Junior to third base. They're going to wave him in. The throw to the plate will be late. The Mariners are going to play for the American League Championship. I don't believe it. It just continues. My, oh, my God.

but tbh I think checking a copy-pasted output token by token is too strict.
We can restructure the test to use some new invariant techniques, or just relax it in a separate PR.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we can check logprobs instead. For now can you add a code comment with a TODO to fix this later?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

) as vllm_model:
llm = vllm_model.llm

Expand Down
30 changes: 19 additions & 11 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,9 @@ def has_blocked_weights():

default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
self._apply_optimization_level_defaults(default_config)

if (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
):
logger.info(
Expand All @@ -692,22 +693,29 @@ def has_blocked_weights():

if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and self.model_config is not None
):
if self.model_config.pooler_config is not None:
if model_config := self.model_config:
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and model_config.pooler_config is not None
):
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
elif (
model_config.is_encoder_decoder
and self.compilation_config.cudagraph_mode
not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
):
logger.info_once(
"Encoder-decoder models do not support %s. "
"Overriding cudagraph_mode to FULL_DECODE_ONLY.",
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
Expand Down
11 changes: 10 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,8 @@ def _get_encoder_seq_lens(
if not isinstance(kv_cache_spec, CrossAttentionSpec):
return None, None

# Zero out buffer for padding requests that are not actually scheduled (CGs)
self.encoder_seq_lens.np[:num_reqs] = 0
# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
for req_id in num_scheduled_tokens:
Expand Down Expand Up @@ -2764,6 +2766,7 @@ def _determine_batch_execution_and_padding(
# be improved in model runner v2)
force_uniform_decode: bool | None = None,
force_has_lora: bool | None = None,
num_encoder_reqs: int = 0,
) -> tuple[
CUDAGraphMode,
BatchDescriptor,
Expand All @@ -2780,6 +2783,11 @@ def _determine_batch_execution_and_padding(
if force_uniform_decode is None
else force_uniform_decode
)
# Encoder-decoder models only support CG for decoder_step > 0 (no enc_output
is present). Also, chunked-prefill is disabled, so batches are uniform.
has_encoder_output = (
Comment thread
NickLucche marked this conversation as resolved.
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
)

has_lora = (
len(self.input_batch.lora_id_to_lora_request) > 0
Expand All @@ -2799,7 +2807,7 @@ def _determine_batch_execution_and_padding(
)

cudagraph_mode, batch_descriptor = dispatch_cudagraph(
num_tokens_padded, use_cascade_attn
num_tokens_padded, use_cascade_attn or has_encoder_output
)
num_tokens_padded = batch_descriptor.num_tokens

Expand Down Expand Up @@ -2997,6 +3005,7 @@ def execute_model(
num_scheduled_tokens_np=num_scheduled_tokens_np,
max_num_scheduled_tokens=max_num_scheduled_tokens,
use_cascade_attn=cascade_attn_prefix_lens is not None,
num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs),
)

logger.debug(
Expand Down