[Core][Model] Gemma4: Unified FA4 for all layers + FlashAttention mm_prefix support#42175
[Core][Model] Gemma4: Unified FA4 for all layers + FlashAttention mm_prefix support#42175lucianommartins wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for bidirectional attention within multimodal token ranges for FlashAttention, primarily to support Gemma4 models. It updates the configuration to handle heterogeneous head dimensions and implements a correction mechanism in the forward pass to merge causal and bidirectional attention results. However, the current implementation has significant performance bottlenecks due to synchronous CPU-GPU transfers and nested loops in the forward pass. Additionally, the decomposition logic for merging attention states is mathematically incorrect, leading to double-counting of tokens and incorrect KV range indexing.
|
Hi @lucianommartins, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
|
fyi @ywang96 @Isotr0py - a known limitation: mm_prefix attention overlap approximation The two-call decomposition for PrefixLM bidirectional attention produces an approximation, not an exact result. The causal call covers KV positions [0, p] for a query at position p, while the non-causal correction covers [r_start, r_end]. For query tokens inside an mm_prefix range where p > r_start, these ranges overlap in [r_start, p]. The LSE-based merge ( Impact scope:
Correct fix: FA4's |
b12a2b7 to
042d5ab
Compare
|
Documentation preview: https://vllm--42175.org.readthedocs.build/en/42175/ |
|
Hi @lucianommartins, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
| def _precompute_mm_prefix_indices( | ||
| self, | ||
| metadata: "FlashAttentionMetadata", # type: ignore[name-defined] | ||
| req_doc_ranges: dict[int, list[tuple[int, int]]], | ||
| ) -> None: |
There was a problem hiding this comment.
This is too FA specific, which should not be placed at model runner.
BTW, which |
| NOTE: The causal call and the non-causal correction call have | ||
| overlapping KV ranges for tokens within mm_prefix regions. The | ||
| LSE-based merge slightly over-weights keys in the overlap region | ||
| compared to the exact (causal OR mm_prefix) mask that the Triton | ||
| backend computes in a single kernel. This is a known approximation | ||
| until FA4 supports mask_mod with varlen sequences. |
There was a problem hiding this comment.
Hmmm, I worry that this can affect model's accuracy through numeric difference from overlayed weights...
There was a problem hiding this comment.
yes, that's what I explained better on this #42175 (comment)
hey @Isotr0py - it is the |
But I think the latest FA4 entrypoint should have removed this limitation: https://github.com/Dao-AILab/flash-attention/blob/ab66326aaa4fe3529fbc00f3156f3a762dd3141b/flash_attn/cute/interface.py#L588-L614 Perhaps we should update our FA fork? cc @LucasWilkinson |
|
I was under the impression that FA4 does not yet support headdim 512 on blackwell: Dao-AILab/flash-attention#2456 SM90 support was landed with vllm-project/flash-attention#130, presumably that's what you used for these benchmarks? edit: ah yeah, I see they were run on H100. You'll need to update this PR so that it only attempts FA4 for this head size on SM90 |
hey @Isotr0py @LucasWilkinson - have you folks had a chance to look into it? |
|
@MatthewBonanni Can we have Dao-AILab/flash-attention#2224 in our FA fork? I think this PR needs this upstream sync to allow proper bidirectional attention mask computation for Gemma4. |
|
have you had a chance to take a look at it @MatthewBonanni @LucasWilkinson ? |
|
This pull request has merge conflicts that must be resolved before it can be |
|
I've been kinda surprised by gemma 4 family's relatively poor perf on vllm (as compared to qwen, gpt-oss etc), even with tuning |
|
hey @Isotr0py @MatthewBonanni @LucasWilkinson - have you had a chance to ingest Dao-AILab/flash-attention#2224 into vLLM? |
042d5ab to
ee7907c
Compare
|
Hi @lucianommartins, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
ee7907c to
7045eaa
Compare
|
any news here @Isotr0py @MatthewBonanni @LucasWilkinson? it is really important to the Gemma4 devx for the vLLM community |
|
^^ :) |
|
@lucianommartins I've made a sync PR (vllm-project/flash-attention#141) which captures that commit. vLLM-side is #44065. I'll work on getting it landed in the next couple of days |
- Relax max_head_dim threshold from 256→512 on Hopper/Blackwell
where FA4 is available; preserve TRITON_ATTN fallback elsewhere
- Force FA4 (flash_attn_version=4) for ALL Gemma4 layers when FA4
is available, avoiding the mixed FA3+FA4 penalty that caused ~8%
decode throughput regression vs uniform FA4
- Add mm_prefix (PrefixLM bidirectional attention) support to the
FlashAttention backend via two-call decomposition:
- First call: standard causal attention over full sequence
- Second call: non-causal attention restricted to mm_prefix ranges
- Merge via log-sum-exp rescaling (merge_attn_states)
- Zero overhead for text-only batches; ~3% per-step cost for
multimodal batches (kernel launch overhead only)
- Add mm_prefix_range_tensor field to FlashAttentionMetadata and
wire it through gpu_model_runner._set_mm_prefix_range_for_metadata
- Enables FLASH_ATTN backend for all PrefixLM models (Gemma3,
Gemma4, PaliGemma, Molmo2) that were previously restricted to
Triton or FlexAttention
Benchmark results (H100 SXM, Gemma4 family, FA4 vs Triton baseline):
- Prefill throughput: +25% to +40% across E2B/E4B/26B-A4B/31B
- Long context (8K): +27% to +41% throughput
- Very long context: +36% to +70% throughput (E2B/E4B/26B-A4B)
- Prefill TTFT latency: -22% to -31% (faster time-to-first-token)
- Long ctx decode lat: -16% to -27% (faster per-step at 8K context)
- Short ctx decode lat: neutral (~±2%, weight-loading dominated)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
|
waiting for #44065 to be merged 🚀 |
7045eaa to
9161dcc
Compare
|
@lucianommartins #44065 just landed 👍 |
Purpose
Enable Flash Attention 4 (FA4) as the default attention backend for Gemma4 models on Hopper (SM90) and Blackwell (SM100+) GPUs, and add multimodal bidirectional attention (mm_prefix / PrefixLM) support to the FlashAttention backend.
Gemma4 uses heterogeneous head dimensions across layers:
head_dim=256for sliding-window attention andglobal_head_dim=512for full attention. Previously, theGemma4Configgate detected this mismatch and forcedTRITON_ATTNas the only backend that could handle both sizes. With FA4 supportinghead_dimup to 512, this restriction is no longer necessary on FA4-capable hardware.Problem 1: Mixed FA3+FA4 penalty. When FLASH_ATTN was manually selected, the per-layer FA version dispatch assigned FA3 (the Hopper default) to sliding layers (
head_dim=256) and FA4 to full-attention layers (head_dim=512). Benchmarking showed this mixed execution is ~8% slower than uniform FA4 for all layers, because FA4 has benchmarked tile configurations forhead_dim≤256that perform comparably to FA3.Problem 2: mm_prefix not supported by FlashAttention. Gemma4 (and Gemma3, PaliGemma, Molmo2, etc.) use bidirectional attention for multimodal tokens (
use_bidirectional_attention="vision"). TheFlashAttentionBackend.supports_mm_prefix()returnedFalse, forcing these models to Triton or FlexAttention. This blocked FA4 from being used at longer context lengths where the multimodal validation activates.Changes
vllm/model_executor/models/config.py- Gemma4Config:max_head_dim ≤ 512: setflash_attn_version=4for all layers (uniform FA4, no mixed FA3+FA4)TRITON_ATTN(preserves existing safety behavior)flash_attn_versionoverridevllm/v1/attention/backends/flash_attn.py- FlashAttention backend:supports_mm_prefix() → Truemm_prefix_range_tensorfield toFlashAttentionMetadata_apply_mm_prefix_correction():flash_attn_varlen_funccall (produces correct results for text tokens)merge_attn_statesusing LSE rescalingmm_prefix_range_tensor is not None)vllm/v1/worker/gpu_model_runner.py:_set_mm_prefix_range_for_metadata()to compute and setmm_prefix_range_tensorforFlashAttentionMetadataalongside the existingTritonAttentionMetadatahandlingImpact on other models
mm_prefix_range_tensoris never set, so the correction path is never entered.Gemma4Configchanges only activate forGemma4ForCausalLMandGemma4ForConditionalGenerationarchitectures viaMODELS_CONFIG_MAP.Test Plan
Test Result
Benchmarked on 8×H100 SXM 80GB with TP=2 (31B) and TP=1 (E2B/E4B). All Gemma4 model sizes tested.
FA4 vs Triton — Throughput (tokens/sec, higher is better)
FA4 vs Triton — Latency (P50, lower is better)
Key findings
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.