Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ Priority is **1 = highest** (tried first).
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | ❌ | All | N/A |
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256, 512 | ❌ | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `nvfp4` | 16, 32, 64 | 64, 128, 256, 512 | ✅ | ❌ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ✅ | | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ✅ | ✅ | | ✅ | All | ≥10.0 |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ✅ | | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ✅ | ✅ | | ✅ | All | ≥10.0 |
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ✅ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ✅ | ✅ | ❌ | ❌ | Decoder | N/A |
Expand Down
62 changes: 30 additions & 32 deletions vllm/model_executor/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,50 +57,48 @@ def verify_and_update_model_config(model_config: "ModelConfig") -> None:
class Gemma4Config(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
    """Configure attention for heterogeneous head dimensions.

    Gemma4 uses different head dimensions for sliding window
    (``head_dim``) vs full attention (``global_head_dim``) layers. The
    default FA3 on Hopper cannot handle head_dim > 256, which
    causes mixed backend selection and numerical divergence.

    Behavior, in order:

    * If either dimension is missing from ``hf_text_config`` or the two
      are equal, the config is left untouched.
    * If FA4 is supported and the larger dimension is at most 512,
      ``attention_config.flash_attn_version`` is set to 4 unless the user
      already pinned a version.
    * Otherwise, if no backend was chosen explicitly,
      ``attention_config.backend`` is set to ``TRITON_ATTN``.

    Args:
        vllm_config: Configuration object that is inspected and updated
            in place.
    """
    hf_text_config = vllm_config.model_config.hf_text_config
    head_dim = getattr(hf_text_config, "head_dim", None)
    global_head_dim = getattr(hf_text_config, "global_head_dim", None)

    # Nothing to reconcile unless both dimensions are present and differ.
    if head_dim is None or global_head_dim is None or head_dim == global_head_dim:
        return

    # Imported lazily so the early-return path above stays import-free.
    from vllm.v1.attention.backends.fa_utils import is_fa_version_supported

    attention_config = vllm_config.attention_config
    max_head_dim = max(head_dim, global_head_dim)

    if is_fa_version_supported(4) and max_head_dim <= 512:
        # Respect an explicit user choice of FlashAttention version.
        if attention_config.flash_attn_version is None:
            attention_config.flash_attn_version = 4
        # Only announce FA4 when it is the version actually in effect;
        # a user-pinned version other than 4 must not be reported as FA4.
        if attention_config.flash_attn_version == 4:
            logger.info(
                "Gemma4 model has heterogeneous head dimensions "
                "(head_dim=%d, global_head_dim=%d). Using FA4 for "
                "all layers to avoid mixed FA3/FA4 penalty.",
                head_dim,
                global_head_dim,
            )
    elif attention_config.backend is None:
        # Reached when FA4 is unsupported or the larger head dimension
        # exceeds 512; an explicit user backend is never overridden.
        from vllm.v1.attention.backends.registry import (
            AttentionBackendEnum,
        )

        attention_config.backend = AttentionBackendEnum.TRITON_ATTN
        logger.info(
            "Gemma4 model has heterogeneous head dimensions "
            "(head_dim=%d, global_head_dim=%d). FA4 not available, "
            "forcing TRITON_ATTN backend.",
            head_dim,
            global_head_dim,
        )
Expand Down
108 changes: 107 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
)
return kv_cache_dtype in ["auto", "float16", "bfloat16"]

@classmethod
def supports_mm_prefix(cls) -> bool:
    """Report whether this backend supports mm_prefix; always ``True``."""
    return True

@classmethod
def supports_sink(cls) -> bool:
if not is_flash_attn_varlen_func_available():
Expand Down Expand Up @@ -255,6 +259,19 @@ class FlashAttentionMetadata:

causal: bool = True

# PrefixLM bidirectional ranges for multimodal tokens.
# Shape: (num_seqs, max_ranges, 2) int32, [start, end] per range.
mm_prefix_range_tensor: torch.Tensor | None = None

# Precomputed indices for mm_prefix correction (set by
# gpu_model_runner to avoid CPU-GPU sync in the forward pass).
mm_prefix_indices: torch.Tensor | None = None
mm_prefix_cu_seqlens: torch.Tensor | None = None
mm_prefix_seqlens_k: torch.Tensor | None = None
mm_prefix_bt_indices: torch.Tensor | None = None
mm_prefix_max_seqlen_q: int = 0
mm_prefix_max_seqlen_k: int = 0


def _get_sliding_window_configs(
vllm_config: VllmConfig,
Expand Down Expand Up @@ -793,7 +810,11 @@ def forward(
if self.sliding_window is not None
else None
)
flash_attn_varlen_func(
need_mm_prefix = (
attn_metadata.mm_prefix_range_tensor is not None
and attn_metadata.causal
)
result = flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
Expand All @@ -815,7 +836,26 @@ def forward(
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
return_softmax_lse=need_mm_prefix,
)
if need_mm_prefix:
causal_out, causal_lse = result
_apply_mm_prefix_correction(
causal_out,
causal_lse,
query[:num_actual_tokens],
key_cache,
value_cache,
attn_metadata,
scale=self.scale,
sliding_window=sliding_window_size,
logits_soft_cap=self.logits_soft_cap,
block_table=block_table,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
return output

# Cascade attention (rare case).
Expand Down Expand Up @@ -1051,6 +1091,72 @@ def _forward_encoder_attention(
return output


def _apply_mm_prefix_correction(
causal_out: torch.Tensor,
causal_lse: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
scale: float,
sliding_window: list[int] | None,
logits_soft_cap: float,
block_table: torch.Tensor,
fa_version: int,
q_descale: torch.Tensor | None,
k_descale: torch.Tensor | None,
v_descale: torch.Tensor | None,
) -> None:
"""Correct multimodal token attention from causal to bidirectional.

Uses precomputed indices from FlashAttentionMetadata (populated by
gpu_model_runner before the forward pass) to avoid CPU-GPU sync.

NOTE: The causal call and the non-causal correction call have
overlapping KV ranges for tokens within mm_prefix regions. The
LSE-based merge slightly over-weights keys in the overlap region
compared to the exact (causal OR mm_prefix) mask that the Triton
backend computes in a single kernel. This is a known approximation
until FA4 supports mask_mod with varlen sequences.
Comment on lines +1115 to +1120
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, I worry that this can affect model's accuracy through numeric difference from overlayed weights...

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that's what I explained better on this #42175 (comment)

"""
mm_idx = attn_metadata.mm_prefix_indices
if mm_idx is None or mm_idx.numel() == 0:
return

bt_indices = attn_metadata.mm_prefix_bt_indices

mm_out, mm_lse = flash_attn_varlen_func(
q=query[mm_idx],
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.mm_prefix_cu_seqlens,
max_seqlen_q=attn_metadata.mm_prefix_max_seqlen_q,
seqused_k=attn_metadata.mm_prefix_seqlens_k,
max_seqlen_k=attn_metadata.mm_prefix_max_seqlen_k,
softmax_scale=scale,
causal=False,
window_size=sliding_window,
block_table=(block_table[bt_indices] if block_table is not None else None),
softcap=logits_soft_cap,
fa_version=fa_version,
q_descale=(q_descale[bt_indices] if q_descale is not None else None),
k_descale=(k_descale[bt_indices] if k_descale is not None else None),
v_descale=(v_descale[bt_indices] if v_descale is not None else None),
return_softmax_lse=True,
)

# Fancy indexing creates a copy, so merge into a temporary and scatter back.
merged = causal_out[mm_idx].clone()
merge_attn_states(
merged,
merged.clone(),
causal_lse[:, mm_idx],
mm_out,
mm_lse,
)
causal_out[mm_idx] = merged
Comment thread
lucianommartins marked this conversation as resolved.


def use_cascade_attention(
common_prefix_len: int,
query_lens: np.ndarray,
Expand Down
5 changes: 5 additions & 0 deletions vllm/v1/spec_decode/dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def __init__(
@override
def _create_draft_vllm_config(self) -> VllmConfig:
base = super()._create_draft_vllm_config()
# The draft model is text-only — clear the target's multimodal
# flag so flash_attn is not rejected for mm_prefix support.
arch = base.model_config.model_arch_config
if arch.is_mm_prefix_lm:
base.model_config.model_arch_config = replace(arch, is_mm_prefix_lm=False)
return replace(
base,
attention_config=replace(
Expand Down
78 changes: 76 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@
)

if TYPE_CHECKING:
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
Expand Down Expand Up @@ -6777,6 +6778,9 @@ def _set_mm_prefix_range_for_metadata(
computing mm_prefix_range_tensor once and sharing it across all
metadata objects to avoid redundant host-to-device transfers.
"""
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata,
)
from vllm.v1.attention.backends.triton_attn import (
TritonAttentionMetadata,
)
Expand All @@ -6794,8 +6798,7 @@ def _set_mm_prefix_range_for_metadata(
for metadata in metadata_list:
metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]

# Only compute tensor for TritonAttentionMetadata
if isinstance(metadata, TritonAttentionMetadata):
if isinstance(metadata, (TritonAttentionMetadata, FlashAttentionMetadata)):
if shared_tensor is None:
shared_tensor = (
TritonAttentionMetadata.compute_mm_prefix_range_tensor(
Expand All @@ -6806,6 +6809,77 @@ def _set_mm_prefix_range_for_metadata(
)
metadata.mm_prefix_range_tensor = shared_tensor

# Precompute mm_prefix correction indices for FlashAttention
# on CPU to avoid GPU-tensor .item() calls in the forward pass.
if isinstance(metadata, FlashAttentionMetadata):
self._precompute_mm_prefix_indices(metadata, req_doc_ranges)

def _precompute_mm_prefix_indices(
self,
metadata: "FlashAttentionMetadata",
req_doc_ranges: dict[int, list[tuple[int, int]]],
) -> None:
Comment on lines +6817 to +6821
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too FA specific, which should not be placed at model runner.

"""Precompute mm_prefix correction indices on CPU.

Reads query_start_loc and seq_lens from CPU-side buffers
(no GPU sync) and stores the result as GPU tensors on the
metadata for use in _apply_mm_prefix_correction.
"""
num_reqs = self.input_batch.num_reqs
qsl = self.query_start_loc.np
seq_lens_cpu = self.seq_lens.cpu()

mm_token_indices: list[int] = []
mm_cu_seqlens = [0]
mm_seqlens_k: list[int] = []
mm_bt_indices: list[int] = []

for seq_idx in range(num_reqs):
ranges = req_doc_ranges.get(seq_idx, [])
if not ranges:
continue
q_start = int(qsl[seq_idx])
q_end = int(qsl[seq_idx + 1])
query_len = q_end - q_start
seq_len = int(seq_lens_cpu[seq_idx])
context_len = seq_len - query_len

for r_start, r_end in ranges:
if r_start >= r_end:
continue
tokens = [
q_start + off
for off in range(query_len)
if r_start <= context_len + off <= r_end
]
if tokens:
mm_token_indices.extend(tokens)
mm_cu_seqlens.append(mm_cu_seqlens[-1] + len(tokens))
mm_seqlens_k.append(r_end - r_start + 1)
mm_bt_indices.append(seq_idx)

if not mm_token_indices:
return

device = metadata.seq_lens.device # type: ignore[union-attr]
metadata.mm_prefix_indices = torch.tensor(
mm_token_indices, dtype=torch.long, device=device
)
metadata.mm_prefix_cu_seqlens = torch.tensor(
mm_cu_seqlens, dtype=torch.int32, device=device
)
metadata.mm_prefix_seqlens_k = torch.tensor(
mm_seqlens_k, dtype=torch.int32, device=device
)
metadata.mm_prefix_bt_indices = torch.tensor(
mm_bt_indices, dtype=torch.long, device=device
)
metadata.mm_prefix_max_seqlen_q = max(
mm_cu_seqlens[i + 1] - mm_cu_seqlens[i]
for i in range(len(mm_cu_seqlens) - 1)
)
metadata.mm_prefix_max_seqlen_k = max(mm_seqlens_k)

def may_reinitialize_input_batch(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None:
Expand Down
Loading