Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .buildkite/test_areas/spec_decode.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ steps:
source_file_dependencies:
- vllm/v1/spec_decode/
- vllm/v1/worker/gpu/spec_decode/
- vllm/v1/attention/backends/
- vllm/transformers_utils/configs/speculators/
- tests/v1/e2e/spec_decode/
commands:
Expand Down
30 changes: 25 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6556,8 +6556,21 @@ def initialize_attn_backend(
assert len(self.attn_groups) == 0, "Attention backends are already initialized"

class AttentionGroupKey(NamedTuple):
    """Deduplication key for attention groups within a KV cache group.

    Splits on per-rank ``num_heads_q`` in addition to backend + spec
    so layers with different Q-head counts (e.g. a spec-decode draft
    with fewer attention heads than its target) get separate metadata
    builders. The builders' scratch (e.g. ``softmax_segm_*`` in
    ``triton_attn``, ``num_qo_heads`` in FlashInfer) is sized by
    ``num_heads_q`` and assumes uniformity within the group; see
    ``get_num_attention_heads_from_layers`` in
    ``vllm/v1/attention/backends/utils.py``.
    """

    # Attention backend class used for the layers that share this key.
    attn_backend: type[AttentionBackend]
    # KV cache spec of those layers. The visible caller unwraps a
    # ``UniformTypeKVCacheSpecs`` to the per-layer spec before storing it.
    kv_cache_spec: KVCacheSpec
    # Q-head count of those layers. The visible caller reads the layer's
    # ``num_heads`` attribute and uses 0 when the layer does not have one.
    num_heads_q: int

def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
Expand Down Expand Up @@ -6586,9 +6599,16 @@ def get_attn_backends_for_group(
layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
key = (full_cls_name, layer_kv_cache_spec)
# Non-Attention layer types (e.g. Mamba1, ShortConv) do not
# expose ``num_heads``; fall back to 0 so they cluster as
# before. Such layers never coexist with Attention in a
# single KV cache group (different KVCacheSpec), so the
# fallback can never spuriously merge them with attention
# layers.
num_heads_q = getattr(layers[layer_name], "num_heads", 0)
key = (full_cls_name, layer_kv_cache_spec, num_heads_q)
attn_backends[key] = AttentionGroupKey(
attn_backend, layer_kv_cache_spec
attn_backend, layer_kv_cache_spec, num_heads_q
)
attn_backend_layers[key].append(layer_name)
return (
Expand All @@ -6601,11 +6621,11 @@ def create_attn_groups(
kv_cache_group_id: int,
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
for key, layer_names in attn_backends_map.items():
attn_group = AttentionGroup(
attn_backend,
key.attn_backend,
layer_names,
kv_cache_spec,
key.kv_cache_spec,
kv_cache_group_id,
)

Expand Down
Loading