diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index ef9c2676d755..9275725314e4 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -263,18 +263,6 @@ def get_cudagraph_support(
         vllm_config: "VllmConfig",
         kv_cache_spec: "AttentionSpec",
     ) -> AttentionCGSupport:
-        # FA2 does not support CUDA graphs with encoder-decoder models due to
-        # accuracy issues reported in https://github.com/vllm-project/vllm/issues/33091
-        if (
-            vllm_config.model_config.is_encoder_decoder
-            and get_flash_attn_version() == 2
-        ):
-            logger.warning_once(
-                "FlashAttention2 does not support CUDA graphs with "
-                "encoder-decoder models due to accuracy issues reported in #33091. "
-                "Disabling CUDA graph."
-            )
-            return AttentionCGSupport.NEVER
         return cls._cudagraph_support
 
     def __init__(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 1bd6e5116a7d..8ace5e2c383f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1395,12 +1395,14 @@ def _get_encoder_seq_lens(
         num_scheduled_tokens: dict[str, int],
         kv_cache_spec: KVCacheSpec,
         num_reqs: int,
+        for_cudagraph_capture: bool = False,
     ) -> tuple[torch.Tensor | None, np.ndarray | None]:
         if not isinstance(kv_cache_spec, CrossAttentionSpec):
             return None, None
 
         # Zero out buffer for padding requests that are not actually scheduled (CGs)
         self.encoder_seq_lens.np[:num_reqs] = 0
+
         # Build encoder_seq_lens array mapping request indices to
         # encoder lengths for inputs scheduled in this batch
         for req_id in num_scheduled_tokens:
@@ -1417,6 +1419,15 @@ def _get_encoder_seq_lens(
                     feature.mm_position.length for feature in req_state.mm_features
                 )
                 self.encoder_seq_lens.np[req_index] = encoder_input_tokens
+        if for_cudagraph_capture:
+            # During CUDA graph capture, we need to use realistic encoder lengths
+            # so that max_seqlen_k is captured with the correct value.
+            max_encoder_len = getattr(
+                self.model_config.hf_config,
+                "max_source_positions",
+                self.max_encoder_len,
+            )
+            self.encoder_seq_lens.np[:num_reqs] = max_encoder_len
         self.encoder_seq_lens.copy_to_gpu(num_reqs)
         encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs]
 
@@ -1834,6 +1845,7 @@ def _build_attn_group_metadata(
                 num_scheduled_tokens or {},
                 kv_cache_group.kv_cache_spec,
                 num_reqs_padded,
+                for_cudagraph_capture=for_cudagraph_capture,
             )
             if kv_cache_gid > 0:
                 cm.block_table_tensor = _get_block_table(kv_cache_gid)