diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 2b206cbcf267..4abc014d1fbe 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -62,6 +62,7 @@ KVCacheLayoutType, get_dcp_local_seq_lens, get_kv_cache_layout, + get_num_attention_heads_from_layers, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills, @@ -607,9 +608,10 @@ def __init__( self.use_dcp and vllm_config.parallel_config.dcp_comm_backend == "a2a" ) - self.num_qo_heads = self.model_config.get_num_attention_heads( - self.vllm_config.parallel_config - ) + # Compatible with models with non-uniform per-layer head counts. + self.num_qo_heads = get_num_attention_heads_from_layers( + vllm_config, layer_names + ) or self.model_config.get_num_attention_heads(self.vllm_config.parallel_config) self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index aa7645f3e294..b68776375fc0 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -30,7 +30,10 @@ CommonAttentionMetadata, MultipleOf, ) -from vllm.v1.attention.backends.utils import get_kv_cache_layout +from vllm.v1.attention.backends.utils import ( + get_kv_cache_layout, + get_num_attention_heads_from_layers, +) from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( triton_reshape_and_cache_flash, @@ -139,9 +142,10 @@ def __init__( self.block_size = kv_cache_spec.block_size model_config = vllm_config.model_config - self.num_heads_q = model_config.get_num_attention_heads( - vllm_config.parallel_config - ) + # Compatible with models with non-uniform per-layer head counts. + self.num_heads_q = get_num_attention_heads_from_layers( + vllm_config, layer_names + ) or model_config.get_num_attention_heads(vllm_config.parallel_config) self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index d09c01eb9059..b73d17e8e5cc 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -136,6 +136,32 @@ def get_per_layer_parameters( return per_layer_params +def get_num_attention_heads_from_layers( + vllm_config: VllmConfig, layer_names: list[str] +) -> int | None: + """Per-TP-rank ``num_heads`` shared by the named Attention layers. + + Use in metadata builders whose plan-time allocations depend on the + head count: the model-wide ``get_num_attention_heads()`` is wrong + for models with non-uniform per-layer head counts. All layers in + one attention group must agree on ``num_heads``; this is asserted. + Returns ``None`` when no matching Attention layer is found. + """ + attn_layers = get_layers_from_vllm_config( + vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + layer_names, + ) + if not attn_layers: + return None + heads = {layer.impl.num_heads for layer in attn_layers.values()} + assert len(heads) == 1, ( + f"All layers in one attention group must share num_heads; " + f"got {heads} for {layer_names}." + ) + return heads.pop() + + def infer_global_hyperparameters( per_layer_params: dict[str, PerLayerParameters], ) -> PerLayerParameters: