diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 7b4fa9252b50..954dbcc65590 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -56,21 +56,24 @@ def verify_and_update_model_config(model_config: "ModelConfig") -> None:
 class Gemma4Config(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        """Force unified attention backend for models with heterogeneous
-        head dimensions.
-
-        Some Gemma4 variants use different head dimensions for
-        sliding window (head_dim) vs full attention (global_head_dim) layers.
-        When global_head_dim > 256, FlashAttention rejects those layers
-        (head_size <= 256 kernel limit), causing vLLM to select a different
-        backend for each layer type. This mixed-backend execution produces
-        numerical divergence and output corruption.
-
-        The fix detects heterogeneous head dimensions from the model config
-        and forces TRITON_ATTN (which has no head_size ceiling) for all
-        layers when the user hasn't explicitly chosen a backend.
-
-        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
+        """Allow per-layer attention backend selection for models with
+        heterogeneous head dimensions.
+
+        Gemma 4 uses different head dimensions for sliding-window
+        (head_dim, typically 256) vs full-attention (global_head_dim,
+        typically 512) layers. Each ``Attention`` layer calls
+        ``get_attn_backend()`` with its own ``head_size``, and the
+        ``@cache``-decorated selector returns a distinct backend per
+        unique configuration. This means sliding-window layers
+        (head_dim=256) automatically pick FlashAttention while
+        full-attention layers (global_head_dim=512, which exceeds FA's
+        head_size<=256 kernel limit) fall back to the next-best
+        backend (e.g. Triton).
+
+        Previously this method forced TRITON_ATTN globally, which
+        penalised the ~83% of layers that *can* run FlashAttention.
+
+        NOTE: Heterogeneous head_sizes (head_dim != global_head_dim)
         require NixlConnector changes to support per-layer KV transfer
         with different head dimensions for prefill-decode disaggregation.
         """
@@ -78,31 +81,60 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         head_dim = getattr(hf_text_config, "head_dim", None)
         global_head_dim = getattr(hf_text_config, "global_head_dim", None)
 
-        # Only force Triton when head dimensions actually differ AND the
-        # larger one exceeds FlashAttention's kernel limit (head_size <= 256).
-        # This avoids unnecessary backend forcing on smaller models where
-        # the config carries global_head_dim but all layers can still use
-        # the same FA backend.
-        max_head_dim = max(head_dim or 0, global_head_dim or 0)
         if (
             head_dim is not None
             and global_head_dim is not None
             and head_dim != global_head_dim
-            and max_head_dim > 256
-            and vllm_config.attention_config.backend is None
         ):
-            from vllm.v1.attention.backends.registry import (
-                AttentionBackendEnum,
-            )
+            # Count sliding-window vs full-attention layers so users
+            # can see the expected backend split.
+            layer_types = getattr(hf_text_config, "layer_types", [])
+            n_full = sum(1 for t in layer_types if t == "full_attention")
+            n_sliding = len(layer_types) - n_full
+
+            max_head = max(head_dim, global_head_dim)
+            explicit_backend = vllm_config.attention_config.backend
+
+            if explicit_backend is None and max_head > 256:
+                # No user override and the larger head_dim exceeds
+                # FlashAttention's kernel limit (head_size <= 256).
+                # Per-layer selection will route sliding-window layers
+                # to FlashAttention and full-attention layers to a
+                # fallback backend (e.g. Triton).
+                logger.info(
+                    "Gemma4 model has heterogeneous head dimensions "
+                    "(head_dim=%d, global_head_dim=%d). %d sliding-window "
+                    "layers will use FlashAttention; %d full-attention "
+                    "layers will fall back to a compatible backend.",
+                    head_dim,
+                    global_head_dim,
+                    n_sliding,
+                    n_full,
+                )
+            else:
+                logger.info(
+                    "Gemma4 model has heterogeneous head dimensions "
+                    "(head_dim=%d, global_head_dim=%d).",
+                    head_dim,
+                    global_head_dim,
+                )
 
-            vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
-            logger.info(
-                "Gemma4 model has heterogeneous head dimensions "
-                "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
-                "backend to prevent mixed-backend numerical divergence.",
-                head_dim,
-                global_head_dim,
-            )
+            # If the user explicitly forced a single backend, warn when
+            # it cannot handle the larger head dimension. The per-layer
+            # selector would raise at model init time anyway; surfacing
+            # the conflict early gives a clearer diagnostic.
+            if explicit_backend is not None:
+                backend_cls = explicit_backend.get_class()
+                if not backend_cls.supports_head_size(max_head):
+                    logger.warning(
+                        "Explicitly selected backend %s does not support "
+                        "head_size=%d (required by full-attention layers). "
+                        "Those layers will fail at runtime. Consider "
+                        "removing --attention-backend to let each layer "
+                        "auto-select the optimal backend.",
+                        explicit_backend.name,
+                        max_head,
+                    )
 
 
 class GptOssForCausalLMConfig(VerifyAndUpdateConfig):