From 72bbc3f96d5b6a873392537707532d28627867fb Mon Sep 17 00:00:00 2001 From: zhanda Date: Thu, 14 May 2026 10:27:55 -0400 Subject: [PATCH 1/3] [Bugfix] Source num_qo_heads from served Attention layers in FlashInfer/Triton metadata builders Signed-off-by: zhanda --- vllm/v1/attention/backends/flashinfer.py | 8 +++++--- vllm/v1/attention/backends/triton_attn.py | 12 ++++++++---- vllm/v1/attention/backends/utils.py | 24 +++++++++++++++++++++++ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index dc4b1cccac7f..ad05fcf5772b 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -62,6 +62,7 @@ KVCacheLayoutType, get_dcp_local_seq_lens, get_kv_cache_layout, + get_num_attention_heads_from_layers, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills, @@ -607,9 +608,10 @@ def __init__( self.use_dcp and vllm_config.parallel_config.dcp_comm_backend == "a2a" ) - self.num_qo_heads = self.model_config.get_num_attention_heads( - self.vllm_config.parallel_config - ) + # Compatible with models with non-uniform per-layer head counts. + self.num_qo_heads = get_num_attention_heads_from_layers( + vllm_config, layer_names + ) or self.model_config.get_num_attention_heads(self.vllm_config.parallel_config) self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 96784ca1fe1e..402c46a483a4 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -30,7 +30,10 @@ CommonAttentionMetadata, MultipleOf, ) -from vllm.v1.attention.backends.utils import get_kv_cache_layout +from vllm.v1.attention.backends.utils import ( + get_kv_cache_layout, + get_num_attention_heads_from_layers, +) from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( triton_reshape_and_cache_flash, @@ -138,9 +141,10 @@ def __init__( self.block_size = kv_cache_spec.block_size model_config = vllm_config.model_config - self.num_heads_q = model_config.get_num_attention_heads( - vllm_config.parallel_config - ) + # Compatible with models with non-uniform per-layer head counts. + self.num_heads_q = get_num_attention_heads_from_layers( + vllm_config, layer_names + ) or model_config.get_num_attention_heads(vllm_config.parallel_config) self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 43cbcfec1844..d9d3adfdb44d 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -136,6 +136,30 @@ def get_per_layer_parameters( return per_layer_params +def get_num_attention_heads_from_layers( + vllm_config: VllmConfig, layer_names: list[str] +) -> int | None: + """Per-TP-rank ``num_heads`` shared by the named Attention layers. + + Use in metadata builders whose plan-time allocations depend on the + head count: the model-wide ``get_num_attention_heads()`` is wrong + for models with non-uniform per-layer head counts. All layers in + one attention group must agree on ``num_heads``; this is asserted. + Returns ``None`` when no matching Attention layer is found. + """ + attn_layers = get_layers_from_vllm_config( + vllm_config, AttentionLayerBase, layer_names + ) + if not attn_layers: + return None + heads = {layer.num_heads for layer in attn_layers.values()} + assert len(heads) == 1, ( + f"All layers in one attention group must share num_heads; " + f"got {heads} for {layer_names}." + ) + return heads.pop() + + def infer_global_hyperparameters( per_layer_params: dict[str, PerLayerParameters], ) -> PerLayerParameters: From 2c9c0edad4ae4ffef497e786b11882eee71a1fe4 Mon Sep 17 00:00:00 2001 From: zhanda Date: Thu, 21 May 2026 01:00:50 -0400 Subject: [PATCH 2/3] Fix mypy lint: silence type-abstract and attr-defined errors Signed-off-by: zhanda --- vllm/v1/attention/backends/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index c9f0b7633702..1c76930234da 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -148,11 +148,13 @@ def get_num_attention_heads_from_layers( Returns ``None`` when no matching Attention layer is found. """ attn_layers = get_layers_from_vllm_config( - vllm_config, AttentionLayerBase, layer_names + vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + layer_names, ) if not attn_layers: return None - heads = {layer.num_heads for layer in attn_layers.values()} + heads = {layer.num_heads for layer in attn_layers.values()} # type: ignore[attr-defined] assert len(heads) == 1, ( f"All layers in one attention group must share num_heads; " f"got {heads} for {layer_names}." From 913e30505e98e5c58a85443277867d8c1a6c3421 Mon Sep 17 00:00:00 2001 From: zhanda Date: Thu, 21 May 2026 01:03:58 -0400 Subject: [PATCH 3/3] Use layer.impl.num_heads for typed access Both impl and impl.num_heads are typed via AttentionLayerBase / AttentionImplBase, so we can drop the attr-defined ignore. Signed-off-by: zhanda --- vllm/v1/attention/backends/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 1c76930234da..b73d17e8e5cc 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -154,7 +154,7 @@ def get_num_attention_heads_from_layers( ) if not attn_layers: return None - heads = {layer.num_heads for layer in attn_layers.values()} # type: ignore[attr-defined] + heads = {layer.impl.num_heads for layer in attn_layers.values()} assert len(heads) == 1, ( f"All layers in one attention group must share num_heads; " f"got {heads} for {layer_names}."