Skip to content

[Core][Model] Gemma4: Unified FA4 for all layers + FlashAttention mm_prefix support#42175

Open
lucianommartins wants to merge 1 commit into
vllm-project:mainfrom
lucianommartins:lucianommartins/gemma4-fa4
Open

[Core][Model] Gemma4: Unified FA4 for all layers + FlashAttention mm_prefix support#42175
lucianommartins wants to merge 1 commit into
vllm-project:mainfrom
lucianommartins:lucianommartins/gemma4-fa4

Conversation

@lucianommartins
Copy link
Copy Markdown
Contributor

Purpose

Enable Flash Attention 4 (FA4) as the default attention backend for Gemma4 models on Hopper (SM90) and Blackwell (SM100+) GPUs, and add multimodal bidirectional attention (mm_prefix / PrefixLM) support to the FlashAttention backend.

Gemma4 uses heterogeneous head dimensions across layers: head_dim=256 for sliding-window attention and global_head_dim=512 for full attention. Previously, the Gemma4Config gate detected this mismatch and forced TRITON_ATTN as the only backend that could handle both sizes. With FA4 supporting head_dim up to 512, this restriction is no longer necessary on FA4-capable hardware.

Problem 1: Mixed FA3+FA4 penalty. When FLASH_ATTN was manually selected, the per-layer FA version dispatch assigned FA3 (the Hopper default) to sliding layers (head_dim=256) and FA4 to full-attention layers (head_dim=512). Benchmarking showed this mixed execution is ~8% slower than uniform FA4 for all layers, because FA4 has benchmarked tile configurations for head_dim≤256 that perform comparably to FA3.

Problem 2: mm_prefix not supported by FlashAttention. Gemma4 (and Gemma3, PaliGemma, Molmo2, etc.) use bidirectional attention for multimodal tokens (use_bidirectional_attention="vision"). The FlashAttentionBackend.supports_mm_prefix() returned False, forcing these models to Triton or FlexAttention. This blocked FA4 from being used at longer context lengths where the multimodal validation activates.

Changes

vllm/model_executor/models/config.py - Gemma4Config:

  • When FA4 is available and max_head_dim ≤ 512: set flash_attn_version=4 for all layers (uniform FA4, no mixed FA3+FA4)
  • When FA4 is not available: fall back to TRITON_ATTN (preserves existing safety behavior)
  • Respects user-explicit flash_attn_version override

vllm/v1/attention/backends/flash_attn.py - FlashAttention backend:

  • Add supports_mm_prefix() → True
  • Add mm_prefix_range_tensor field to FlashAttentionMetadata
  • Implement two-call decomposition for mm_prefix correction in _apply_mm_prefix_correction():
    1. Main causal flash_attn_varlen_func call (produces correct results for text tokens)
    2. Non-causal call restricted to mm_prefix ranges (corrects multimodal tokens)
    3. Merge via merge_attn_states using LSE rescaling
  • Zero overhead for text-only batches (the correction is gated on mm_prefix_range_tensor is not None)

vllm/v1/worker/gpu_model_runner.py:

  • Extend _set_mm_prefix_range_for_metadata() to compute and set mm_prefix_range_tensor for FlashAttentionMetadata alongside the existing TritonAttentionMetadata handling

Impact on other models

  • PrefixLM models (Gemma3, PaliGemma, Molmo2, Moondream3, Bagel): previously restricted to Triton or FlexAttention for mm_prefix. These models can now use FLASH_ATTN as a backend candidate. The two-call decomposition is model-agnostic and mathematically correct for any PrefixLM model
  • Non-PrefixLM models: no change. mm_prefix_range_tensor is never set, so the correction path is never entered.
  • Non-Gemma4 models: the Gemma4Config changes only activate for Gemma4ForCausalLM and Gemma4ForConditionalGeneration architectures via MODELS_CONFIG_MAP.

Test Plan

# 1. Verify FA4 auto-selection for Gemma4
python -c "
from vllm.config import VllmConfig, ModelConfig
vc = VllmConfig(model_config=ModelConfig(model='google/gemma-4-31B-it', trust_remote_code=True, max_model_len=8192))
assert vc.attention_config.flash_attn_version == 4
assert vc.attention_config.backend is None  # auto-selects FLASH_ATTN
print('PASS: FA4 auto-selected')
"

# 2. Verify mm_prefix support (previously failed with 'partial multimodal token full attention not supported')
CUDA_VISIBLE_DEVICES=0,1 vllm bench throughput \
    --model google/gemma-4-31B-it --attention-backend FLASH_ATTN \
    --dataset-name random --random-input-len 8192 --random-output-len 128 \
    --num-prompts 5 --tensor-parallel-size 2 --max-model-len 16384 \
    --gpu-memory-utilization 0.90 --num-gpu-blocks-override 8000 \
    --trust-remote-code --dtype bfloat16

# 3. Verify Triton fallback on non-FA4 platforms (config gate)
python -c "
from unittest.mock import patch
from vllm.config import VllmConfig, ModelConfig
with patch('vllm.v1.attention.backends.fa_utils.is_fa_version_supported', return_value=False):
    vc = VllmConfig(model_config=ModelConfig(model='google/gemma-4-31B-it', trust_remote_code=True, max_model_len=8192))
    assert str(vc.attention_config.backend) == 'AttentionBackendEnum.TRITON_ATTN'
    print('PASS: Triton fallback when FA4 unavailable')
"

# 4. Throughput benchmark: FA4 vs Triton baseline
CUDA_VISIBLE_DEVICES=0,1 vllm bench throughput \
    --model google/gemma-4-31B-it --attention-backend FLASH_ATTN \
    --dataset-name random --random-input-len 4096 --random-output-len 16 \
    --num-prompts 100 --tensor-parallel-size 2 --max-model-len 8192 \
    --trust-remote-code --dtype bfloat16

Test Result

Benchmarked on 8×H100 SXM 80GB with TP=2 (31B) and TP=1 (E2B/E4B). All Gemma4 model sizes tested.

FA4 vs Triton — Throughput (tokens/sec, higher is better)

Benchmark E2B (~2B) E4B (~4B) 26B-A4B (MoE) 31B
Prefill (4K input) +40% +24% +25% +33%
Long context (8K) +41% +28% +27% +28%
Very long context (15-16K) +70% +47% +36%
Mixed (1K/1K) +5% +5% +1% +4%
High batch (256 in/out) +2% +5% +1% +5%

FA4 vs Triton — Latency (P50, lower is better)

Benchmark E2B E4B 26B-A4B 31B
Prefill TTFT -31% -22% -24% -22%
Long ctx decode (8K) -27% -19% -20% -16%
Decode b=1 +4% -1% +3% +1%
Decode b=8 +1% +1% +1% -1%

Key findings

  • FA4 wins all throughput scenarios across all 4 model sizes
  • Prefill/long-context improvement: +25-70% throughput, 22-31% faster TTFT
  • Short-context decode: neutral (~±2%) — weight-loading dominated per Amdahl's law
  • Uniform FA4 beats mixed FA3+FA4 by ~8% (kernel path uniformity)
  • mm_prefix two-call decomposition adds zero overhead for text-only requests; ~3% for multimodal batches (kernel launch overhead only)

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the v1 label May 9, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for bidirectional attention within multimodal token ranges for FlashAttention, primarily to support Gemma4 models. It updates the configuration to handle heterogeneous head dimensions and implements a correction mechanism in the forward pass to merge causal and bidirectional attention results. However, the current implementation has significant performance bottlenecks due to synchronous CPU-GPU transfers and nested loops in the forward pass. Additionally, the decomposition logic for merging attention states is mathematically incorrect, leading to double-counting of tokens and incorrect KV range indexing.

Comment thread vllm/v1/attention/backends/flash_attn.py Outdated
Comment thread vllm/v1/attention/backends/flash_attn.py
Comment thread vllm/v1/attention/backends/flash_attn.py Outdated
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 9, 2026

Hi @lucianommartins, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@lucianommartins
Copy link
Copy Markdown
Contributor Author

fyi @ywang96 @Isotr0py - a known limitation: mm_prefix attention overlap approximation

The two-call decomposition for PrefixLM bidirectional attention produces an approximation, not an exact result. The causal call covers KV positions [0, p] for a query at position p, while the non-causal correction covers [r_start, r_end]. For query tokens inside an mm_prefix range where p > r_start, these ranges overlap in [r_start, p]. The LSE-based merge (merge_attn_states) treats the two calls as independent partial results, which over-weights keys in the overlap region compared to the exact (causal OR mm_prefix) mask that the Triton backend computes in a single kernel via compute_kv_seq_mask.

Impact scope:

  • Text-only requests: none (mm_prefix never activates)
  • Multimodal decode: none (query tokens are always past the mm_prefix ranges)
  • Multimodal prefill: overlap grows linearly with position within the range — zero for the first image token, full range for the last. Affects intra-range attention distribution but not text token outputs

Correct fix: FA4's mask_mod callable supports the exact (causal OR mm_prefix) mask, but is currently blocked by interface.py:564-568 (mask_mod with aux_tensors is not yet supported for varlen sequences). Once this upstream limitation is resolved, the two-call decomposition can be replaced with a single mask_mod-based call.

@lucianommartins lucianommartins force-pushed the lucianommartins/gemma4-fa4 branch from b12a2b7 to 042d5ab Compare May 9, 2026 17:03
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 9, 2026

Documentation preview: https://vllm--42175.org.readthedocs.build/en/42175/

@mergify mergify Bot added the documentation Improvements or additions to documentation label May 9, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 9, 2026

Hi @lucianommartins, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Comment on lines +6578 to +6582
def _precompute_mm_prefix_indices(
self,
metadata: "FlashAttentionMetadata", # type: ignore[name-defined]
req_doc_ranges: dict[int, list[tuple[int, int]]],
) -> None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too FA specific, which should not be placed at model runner.

@Isotr0py
Copy link
Copy Markdown
Member

Isotr0py commented May 9, 2026

Correct fix: FA4's mask_mod callable supports the exact (causal OR mm_prefix) mask, but is currently blocked by interface.py:564-568 (mask_mod with aux_tensors is not yet supported for varlen sequences). Once this upstream limitation is resolved, the two-call decomposition can be replaced with a single mask_mod-based call.

BTW, which interface.py are you referring to?

Comment on lines +1128 to +1133
NOTE: The causal call and the non-causal correction call have
overlapping KV ranges for tokens within mm_prefix regions. The
LSE-based merge slightly over-weights keys in the overlap region
compared to the exact (causal OR mm_prefix) mask that the Triton
backend computes in a single kernel. This is a known approximation
until FA4 supports mask_mod with varlen sequences.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, I worry that this can affect model's accuracy through numeric difference from overlayed weights...

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that's what I explained better on this #42175 (comment)

@lucianommartins
Copy link
Copy Markdown
Contributor Author

Correct fix: FA4's mask_mod callable supports the exact (causal OR mm_prefix) mask, but is currently blocked by interface.py:564-568 (mask_mod with aux_tensors is not yet supported for varlen sequences). Once this upstream limitation is resolved, the two-call decomposition can be replaced with a single mask_mod-based call.

BTW, which interface.py are you referring to?

hey @Isotr0py - it is the vllm/vllm_flash_attn/cute/interface.py (the FA4 CuTE-DSL entry point), lines 564-568.

@Isotr0py
Copy link
Copy Markdown
Member

It is the vllm/vllm_flash_attn/cute/interface.py (the FA4 CuTE-DSL entry point), lines 564-568.

But I think the latest FA4 entrypoint should have removed this limitation: https://github.com/Dao-AILab/flash-attention/blob/ab66326aaa4fe3529fbc00f3156f3a762dd3141b/flash_attn/cute/interface.py#L588-L614

Perhaps we should update our FA fork? cc @LucasWilkinson

@MatthewBonanni
Copy link
Copy Markdown
Member

MatthewBonanni commented May 15, 2026

I was under the impression that FA4 does not yet support headdim 512 on blackwell: Dao-AILab/flash-attention#2456

SM90 support was landed with vllm-project/flash-attention#130, presumably that's what you used for these benchmarks?

edit: ah yeah, I see they were run on H100. You'll need to update this PR so that it only attempts FA4 for this head size on SM90

@lucianommartins
Copy link
Copy Markdown
Contributor Author

It is the vllm/vllm_flash_attn/cute/interface.py (the FA4 CuTE-DSL entry point), lines 564-568.

But I think the latest FA4 entrypoint should have removed this limitation: https://github.com/Dao-AILab/flash-attention/blob/ab66326aaa4fe3529fbc00f3156f3a762dd3141b/flash_attn/cute/interface.py#L588-L614

Perhaps we should update our FA fork? cc @LucasWilkinson

hey @Isotr0py @LucasWilkinson - have you folks had a chance to look into it?

@Isotr0py
Copy link
Copy Markdown
Member

Isotr0py commented May 20, 2026

@MatthewBonanni Can we have Dao-AILab/flash-attention#2224 in our FA fork? I think this PR needs this upstream sync to allow proper bidirectional attention mask computation for Gemma4.

@lucianommartins
Copy link
Copy Markdown
Contributor Author

have you had a chance to take a look at it @MatthewBonanni @LucasWilkinson ?

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @lucianommartins.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 23, 2026
@coopslarhette
Copy link
Copy Markdown

I've been kinda surprised by gemma 4 family's relatively poor perf on vllm (as compared to qwen, gpt-oss etc), even with tuning max_num_seq, max_num_batched_tokens to our use case. perhaps this is why? would be really helpful to get this in if so!

@lucianommartins
Copy link
Copy Markdown
Contributor Author

hey @Isotr0py @MatthewBonanni @LucasWilkinson - have you had a chance to ingest Dao-AILab/flash-attention#2224 into vLLM?

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 29, 2026

Hi @lucianommartins, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@lucianommartins lucianommartins force-pushed the lucianommartins/gemma4-fa4 branch from ee7907c to 7045eaa Compare May 29, 2026 01:46
@lucianommartins
Copy link
Copy Markdown
Contributor Author

any news here @Isotr0py @MatthewBonanni @LucasWilkinson? it is really important to the Gemma4 devx for the vLLM community

@coopslarhette
Copy link
Copy Markdown

^^ :)

@MatthewBonanni
Copy link
Copy Markdown
Member

MatthewBonanni commented May 30, 2026

@lucianommartins I've made a sync PR (vllm-project/flash-attention#141) which captures that commit. vLLM-side is #44065. I'll work on getting it landed in the next couple of days

- Relax max_head_dim threshold from 256→512 on Hopper/Blackwell
  where FA4 is available; preserve TRITON_ATTN fallback elsewhere
- Force FA4 (flash_attn_version=4) for ALL Gemma4 layers when FA4
  is available, avoiding the mixed FA3+FA4 penalty that caused ~8%
  decode throughput regression vs uniform FA4
- Add mm_prefix (PrefixLM bidirectional attention) support to the
  FlashAttention backend via two-call decomposition:
  - First call: standard causal attention over full sequence
  - Second call: non-causal attention restricted to mm_prefix ranges
  - Merge via log-sum-exp rescaling (merge_attn_states)
  - Zero overhead for text-only batches; ~3% per-step cost for
    multimodal batches (kernel launch overhead only)
- Add mm_prefix_range_tensor field to FlashAttentionMetadata and
  wire it through gpu_model_runner._set_mm_prefix_range_for_metadata
- Enables FLASH_ATTN backend for all PrefixLM models (Gemma3,
  Gemma4, PaliGemma, Molmo2) that were previously restricted to
  Triton or FlexAttention

Benchmark results (H100 SXM, Gemma4 family, FA4 vs Triton baseline):
- Prefill throughput:    +25% to +40% across E2B/E4B/26B-A4B/31B
- Long context (8K):     +27% to +41% throughput
- Very long context:     +36% to +70% throughput (E2B/E4B/26B-A4B)
- Prefill TTFT latency:  -22% to -31% (faster time-to-first-token)
- Long ctx decode lat:   -16% to -27% (faster per-step at 8K context)
- Short ctx decode lat:  neutral (~±2%, weight-loading dominated)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
@lucianommartins
Copy link
Copy Markdown
Contributor Author

waiting for #44065 to be merged 🚀

@lucianommartins lucianommartins force-pushed the lucianommartins/gemma4-fa4 branch from 7045eaa to 9161dcc Compare May 31, 2026 18:30
@MatthewBonanni
Copy link
Copy Markdown
Member

@lucianommartins #44065 just landed 👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation speculative-decoding v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants