Skip to content
8 changes: 5 additions & 3 deletions vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
KVCacheLayoutType,
get_dcp_local_seq_lens,
get_kv_cache_layout,
get_num_attention_heads_from_layers,
get_per_layer_parameters,
infer_global_hyperparameters,
split_decodes_and_prefills,
Expand Down Expand Up @@ -607,9 +608,10 @@ def __init__(
self.use_dcp and vllm_config.parallel_config.dcp_comm_backend == "a2a"
)

self.num_qo_heads = self.model_config.get_num_attention_heads(
self.vllm_config.parallel_config
)
# Compatible with models with non-uniform per-layer head counts.
self.num_qo_heads = get_num_attention_heads_from_layers(
vllm_config, layer_names
) or self.model_config.get_num_attention_heads(self.vllm_config.parallel_config)

self.num_kv_heads = self.kv_cache_spec.num_kv_heads
self.head_dim = self.kv_cache_spec.head_size
Expand Down
12 changes: 8 additions & 4 deletions vllm/v1/attention/backends/triton_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.attention.backends.utils import (
get_kv_cache_layout,
get_num_attention_heads_from_layers,
)
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
Expand Down Expand Up @@ -139,9 +142,10 @@ def __init__(
self.block_size = kv_cache_spec.block_size

model_config = vllm_config.model_config
self.num_heads_q = model_config.get_num_attention_heads(
vllm_config.parallel_config
)
# Compatible with models with non-uniform per-layer head counts.
self.num_heads_q = get_num_attention_heads_from_layers(
vllm_config, layer_names
) or model_config.get_num_attention_heads(vllm_config.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
self.headdim = model_config.get_head_size()

Expand Down
26 changes: 26 additions & 0 deletions vllm/v1/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,32 @@ def get_per_layer_parameters(
return per_layer_params


def get_num_attention_heads_from_layers(
vllm_config: VllmConfig, layer_names: list[str]
) -> int | None:
"""Per-TP-rank ``num_heads`` shared by the named Attention layers.

Use in metadata builders whose plan-time allocations depend on the
head count: the model-wide ``get_num_attention_heads()`` is wrong
for models with non-uniform per-layer head counts. All layers in
one attention group must agree on ``num_heads``; this is asserted.
Returns ``None`` when no matching Attention layer is found.
"""
attn_layers = get_layers_from_vllm_config(
vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
layer_names,
)
if not attn_layers:
return None
heads = {layer.impl.num_heads for layer in attn_layers.values()}
assert len(heads) == 1, (
f"All layers in one attention group must share num_heads; "
f"got {heads} for {layer_names}."
)
return heads.pop()


def infer_global_hyperparameters(
per_layer_params: dict[str, PerLayerParameters],
) -> PerLayerParameters:
Expand Down
Loading