Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,6 @@ def get_cudagraph_support(
vllm_config: "VllmConfig",
kv_cache_spec: "AttentionSpec",
) -> AttentionCGSupport:
# FA2 does not support CUDA graphs with encoder-decoder models due to
# accuracy issues reported in https://github.com/vllm-project/vllm/issues/33091
if (
vllm_config.model_config.is_encoder_decoder
and get_flash_attn_version() == 2
):
logger.warning_once(
"FlashAttention2 does not support CUDA graphs with "
"encoder-decoder models due to accuracy issues reported in #33091. "
"Disabling CUDA graph."
)
return AttentionCGSupport.NEVER
return cls._cudagraph_support

def __init__(
Expand Down
12 changes: 12 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,12 +1395,14 @@ def _get_encoder_seq_lens(
num_scheduled_tokens: dict[str, int],
kv_cache_spec: KVCacheSpec,
num_reqs: int,
for_cudagraph_capture: bool = False,
) -> tuple[torch.Tensor | None, np.ndarray | None]:
if not isinstance(kv_cache_spec, CrossAttentionSpec):
return None, None

# Zero out buffer for padding requests that are not actually scheduled (CGs)
self.encoder_seq_lens.np[:num_reqs] = 0

# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
for req_id in num_scheduled_tokens:
Expand All @@ -1417,6 +1419,15 @@ def _get_encoder_seq_lens(
feature.mm_position.length for feature in req_state.mm_features
)
self.encoder_seq_lens.np[req_index] = encoder_input_tokens
if for_cudagraph_capture:
# During CUDA graph capture, we need to use realistic encoder lengths
# so that max_seqlen_k is captured with the correct value.
max_encoder_len = getattr(
self.model_config.hf_config,
"max_source_positions",
self.max_encoder_len,
)
self.encoder_seq_lens.np[:num_reqs] = max_encoder_len

self.encoder_seq_lens.copy_to_gpu(num_reqs)
encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs]
Expand Down Expand Up @@ -1834,6 +1845,7 @@ def _build_attn_group_metadata(
num_scheduled_tokens or {},
kv_cache_group.kv_cache_spec,
num_reqs_padded,
for_cudagraph_capture=for_cudagraph_capture,
)
if kv_cache_gid > 0:
cm.block_table_tensor = _get_block_table(kv_cache_gid)
Expand Down