vllm/model_executor/models/config.py (100 changes: 66 additions & 34 deletions)

@@ -56,53 +56,85 @@ def verify_and_update_model_config(model_config: "ModelConfig") -> None:
 class Gemma4Config(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        """Force unified attention backend for models with heterogeneous
-        head dimensions.
-
-        Some Gemma4 variants use different head dimensions for
-        sliding window (head_dim) vs full attention (global_head_dim) layers.
-        When global_head_dim > 256, FlashAttention rejects those layers
-        (head_size <= 256 kernel limit), causing vLLM to select a different
-        backend for each layer type. This mixed-backend execution produces
-        numerical divergence and output corruption.
-
-        The fix detects heterogeneous head dimensions from the model config
-        and forces TRITON_ATTN (which has no head_size ceiling) for all
-        layers when the user hasn't explicitly chosen a backend.
-
-        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
"""Allow per-layer attention backend selection for models with
heterogeneous head dimensions.

Gemma 4 uses different head dimensions for sliding-window
(head_dim, typically 256) vs full-attention (global_head_dim,
typically 512) layers. Each ``Attention`` layer calls
``get_attn_backend()`` with its own ``head_size``, and the
``@cache``-decorated selector returns a distinct backend per
unique configuration. This means sliding-window layers
(head_dim=256) automatically pick FlashAttention while
full-attention layers (global_head_dim=512, which exceeds FA's
head_size<=256 kernel limit) fall back to the next-best
backend (e.g. Triton).

Previously this method forced TRITON_ATTN globally, which
penalised the ~83% of layers that *can* run FlashAttention.

NOTE: Heterogeneous head_sizes (head_dim != global_head_dim)
         require NixlConnector changes to support per-layer KV transfer
         with different head dimensions for prefill-decode disaggregation.
         """
         hf_text_config = vllm_config.model_config.hf_text_config
         head_dim = getattr(hf_text_config, "head_dim", None)
         global_head_dim = getattr(hf_text_config, "global_head_dim", None)

-        # Only force Triton when head dimensions actually differ AND the
-        # larger one exceeds FlashAttention's kernel limit (head_size <= 256).
-        # This avoids unnecessary backend forcing on smaller models where
-        # the config carries global_head_dim but all layers can still use
-        # the same FA backend.
-        max_head_dim = max(head_dim or 0, global_head_dim or 0)
-        if (
-            head_dim is not None
-            and global_head_dim is not None
-            and head_dim != global_head_dim
-            and max_head_dim > 256
-            and vllm_config.attention_config.backend is None
-        ):
-            from vllm.v1.attention.backends.registry import (
-                AttentionBackendEnum,
-            )
+        # Some variants lack global_head_dim or use a single head size
+        # for both attention types; bail out early so max() below never
+        # sees None and the "heterogeneous" log lines stay truthful.
+        if (
+            head_dim is None
+            or global_head_dim is None
+            or head_dim == global_head_dim
+        ):
+            return
+
+        # Count sliding-window vs full-attention layers so users
+        # can see the expected backend split.
+        layer_types = getattr(hf_text_config, "layer_types", [])
+        n_full = sum(1 for t in layer_types if t == "full_attention")
+        n_sliding = len(layer_types) - n_full
+
+        max_head = max(head_dim, global_head_dim)
+        explicit_backend = vllm_config.attention_config.backend
+
+        if explicit_backend is None and max_head > 256:
+            # No user override and the larger head_dim exceeds
+            # FlashAttention's kernel limit (head_size <= 256).
+            # Per-layer selection will route sliding-window layers
+            # to FlashAttention and full-attention layers to a
+            # fallback backend (e.g. Triton).
+            logger.info(
+                "Gemma4 model has heterogeneous head dimensions "
+                "(head_dim=%d, global_head_dim=%d). %d sliding-window "
+                "layers will use FlashAttention; %d full-attention "
+                "layers will fall back to a compatible backend.",
+                head_dim,
+                global_head_dim,
+                n_sliding,
+                n_full,
+            )
+        else:
+            logger.info(
+                "Gemma4 model has heterogeneous head dimensions "
+                "(head_dim=%d, global_head_dim=%d).",
+                head_dim,
+                global_head_dim,
+            )

-            vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
-            logger.info(
-                "Gemma4 model has heterogeneous head dimensions "
-                "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
-                "backend to prevent mixed-backend numerical divergence.",
-                head_dim,
-                global_head_dim,
-            )
+        # If the user explicitly forced a single backend, warn when
+        # it cannot handle the larger head dimension. The per-layer
+        # selector would raise at model init time anyway; surfacing
+        # the conflict early gives a clearer diagnostic.
+        if explicit_backend is not None:
+            backend_cls = explicit_backend.get_class()
+            if not backend_cls.supports_head_size(max_head):
+                logger.warning(
+                    "Explicitly selected backend %s does not support "
+                    "head_size=%d (required by full-attention layers). "
+                    "Those layers will fail at runtime. Consider "
+                    "removing --attention-backend to let each layer "
+                    "auto-select the optimal backend.",
+                    explicit_backend.name,
+                    max_head,
+                )


 class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
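
The per-layer selection the new docstring describes hinges on one detail: the backend selector is memoized on its arguments, so each distinct head_size resolves to its own backend exactly once and all layers sharing that head_size reuse the cached result. The sketch below illustrates that mechanism in isolation; the class names, the MAX_HEAD_SIZE attribute, and the hard-coded 256 limit are illustrative stand-ins taken from the docstring, not vLLM's actual get_attn_backend signature.

from functools import cache


# Illustrative stand-ins for attention backend classes; the real
# vLLM backends and their interfaces differ.
class FlashAttentionBackend:
    MAX_HEAD_SIZE = 256  # FA kernels reject head_size > 256


class TritonAttentionBackend:
    MAX_HEAD_SIZE = None  # no head_size ceiling


@cache
def get_attn_backend(head_size: int):
    # Memoized per unique head_size: every layer with the same
    # head_size receives the same backend class from the cache.
    if head_size <= FlashAttentionBackend.MAX_HEAD_SIZE:
        return FlashAttentionBackend
    return TritonAttentionBackend


# Sliding-window layers (head_dim=256) pick FlashAttention, while
# full-attention layers (global_head_dim=512) fall back to Triton.
assert get_attn_backend(256) is FlashAttentionBackend
assert get_attn_backend(512) is TritonAttentionBackend

Because the cache key is the full argument tuple, a model with heterogeneous head dimensions naturally ends up with a mixed set of backends, which is exactly the behaviour this change stops suppressing.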