From 0dac88ef7f755c7694910a2175eb7839bd1db486 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 4 Feb 2026 12:55:33 +0000 Subject: [PATCH 1/8] Support multiple KV-cache groups in speculative decoding drafters This enables models with multiple KV-cache groups (e.g., Gemma3, GPT-OSS MoE) to be used as drafters in speculative decoding. Key changes: - Refactored CommonAttentionMetadata handling to support a dictionary of metadata per KV-cache group ID (CommonAttnMetadataByGid) - Added per-group slot-mapping buffers for draft model inference - Introduced layer_names_to_kv_cache_gid mapping to correctly route attention layers to their corresponding KV-cache groups New test cases: - Gemma3 (270m): multiple KV-cache groups with mixed attention - GPT-OSS MoE (120b/20b): validates MoE layer resolution in spec decoding Fixes https://github.com/vllm-project/vllm/issues/33133 Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 49 ++++- tests/v1/spec_decode/test_eagle.py | 57 +++-- tests/v1/worker/test_utils.py | 21 +- vllm/v1/attention/backend.py | 6 + vllm/v1/attention/backends/utils.py | 2 - vllm/v1/core/kv_cache_utils.py | 13 +- vllm/v1/spec_decode/draft_model.py | 54 ++--- vllm/v1/spec_decode/eagle.py | 322 ++++++++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 115 +++++----- vllm/v1/worker/utils.py | 13 ++ 10 files changed, 389 insertions(+), 263 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 4905a4120a2c..5984415f10fb 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -616,6 +616,15 @@ def test_mtp_correctness( cleanup_dist_env_and_memory() +def some_high_acceptance_metrics() -> dict: + return { + "sampling_config": greedy_sampling(), + "num_speculative_tokens": 3, + "expected_acceptance_len": 2.8 + 1, + "expected_acceptance_rate": 0.90, + } + + @dataclass class ArgsTest: target_model: str @@ -631,6 +640,8 @@ class ArgsTest: gpu_memory_utilization: float = 
0.5 dataset: str = "test_prompts" num_prompts: int = 100 + # Some settings only get 100% acceptance_rate with VLLM_BATCH_INVARIANT=1 + use_batch_invariance: bool = False cases = [ @@ -652,12 +663,36 @@ class ArgsTest: expected_acceptance_len=2.8 + 1, expected_acceptance_rate=0.9, ), + # Multiple KV Cache groups + ArgsTest( + target_model="google/gemma-3-270m-it", + draft_model="google/gemma-3-270m-it", + sampling_config=greedy_sampling(), + num_speculative_tokens=3, + expected_acceptance_len=4, + expected_acceptance_rate=1, + # Without batch invariance, acceptance rate is ~86% + use_batch_invariance=True, + ), + # GPT-OSS MoE models with different target/draft sizes + # Tests MoE layer resolution in speculative decoding. + ArgsTest( + target_model="openai/gpt-oss-120b", + draft_model="openai/gpt-oss-20b", + # Leave some headroom for CUDA graph capture. + gpu_memory_utilization=0.85, + **some_high_acceptance_metrics(), + ), ] @pytest.mark.parametrize("args", cases) @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool): +def test_draft_model_correctness( + args: ArgsTest, enforce_eager: bool, monkeypatch: pytest.MonkeyPatch +): + if args.use_batch_invariance: + monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1") assert_draft_model_correctness(args, enforce_eager) @@ -753,6 +788,8 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname(): def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): """Compare the outputs using and not using speculative decoding. 
In the greedy decoding case, the outputs must match EXACTLY.""" + attention_config = {"backend": "FLASH_ATTN"} if args.use_batch_invariance else None + test_prompts: list[Messages] = get_messages( dataset=args.dataset, n=args.num_prompts ) @@ -773,6 +810,7 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): tensor_parallel_size=args.target_tensor_parallel_size, enforce_eager=enforce_eager, disable_log_stats=False, # enables get_metrics() + attention_config=attention_config, ) # we don't check the outputs, only check the metrics spec_llm.chat(test_prompts, args.sampling_config) @@ -804,15 +842,6 @@ def get_messages(dataset: str, n: int) -> list[Messages]: raise NotImplementedError(f"Dataset '{dataset}' not implemented") -def some_high_acceptance_metrics() -> dict: - return { - "sampling_config": greedy_sampling(), - "num_speculative_tokens": 3, - "expected_acceptance_len": 2.8 + 1, - "expected_acceptance_rate": 0.90, - } - - def test_merge_toks_kernel(): device = "cuda" merged_len = 5 + 2 # len(target_toks) = 5, batch_size = 2 diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 3158ff0bda95..52f7be518240 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -148,6 +148,11 @@ def test_prepare_next_token_ids(): block_size=16, device=device, ) + kv_cache_group_id = 0 + common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} + # Mock connection of layer names to kv_cache_group_ids + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} expected_valid_sampled_tokens_count = torch.tensor( [2, 5, 0, 0], dtype=torch.int32, device=device @@ -155,7 +160,7 @@ def test_prepare_next_token_ids(): next_token_ids_from_padded, valid_sampled_tokens_count = ( proposer.prepare_next_token_ids_padded( - common_attn_metadata, + common_attn_metadata_by_gid, sampled_token_ids_tensor, mock_requests, mock_input_batch, @@ 
-248,10 +253,17 @@ def test_prepare_inputs(): ) proposer = _create_proposer("eagle", 1) - updated_metadata, token_indices = proposer.prepare_inputs( - common_attn_metadata, sampled_token_ids, num_draft_tokens + # Mock connection of layer names to kv_cache_group_ids + kv_cache_group_id = 0 + common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} + + updated_cm_by_gid, token_indices = proposer.prepare_inputs( + common_attn_metadata_by_gid, sampled_token_ids, num_draft_tokens ) + updated_metadata = updated_cm_by_gid[kv_cache_group_id] assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens) assert token_indices.shape[0] == expected_cu_num_tokens[-1].item() assert torch.equal(token_indices, expected_token_indices) @@ -306,9 +318,17 @@ def test_prepare_inputs_padded(): proposer = _create_proposer("eagle", num_speculative_tokens) - output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = ( + # Mock connection of layer names to kv_cache_group_ids + kv_cache_group_id = 0 + common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} + + output_cm_by_gid, token_indices_to_sample, num_rejected_tokens_gpu = ( proposer.prepare_inputs_padded( - common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count + common_attn_metadata_by_gid, + spec_decode_metadata, + valid_sampled_tokens_count, ) ) @@ -316,6 +336,7 @@ def test_prepare_inputs_padded(): expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device) assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected) + output_metadata = output_cm_by_gid[kv_cache_group_id] assert output_metadata.max_query_len == 3 assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc) assert 
torch.equal(token_indices_to_sample, expected_token_indices_to_sample) @@ -566,13 +587,15 @@ def create_deterministic_logits(token_ids): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() - proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][ - 0 - ].get_metadata_builder.return_value = attn_metadata_builder + attn_group = mock.MagicMock() + attn_group.get_metadata_builder.return_value = attn_metadata_builder + attn_group.layer_names = proposer.attn_layer_names + proposer.runner.attn_groups = [[attn_group]] proposer._get_attention_metadata_builder = mock.MagicMock( return_value=attn_metadata_builder ) + kv_cache_group_id = 0 + common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} result = proposer.propose( target_token_ids=target_token_ids, @@ -580,8 +603,8 @@ def create_deterministic_logits(token_ids): target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, last_token_indices=None, - common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, + cm_by_gid=common_attn_metadata_by_gid, ) assert result.shape == (batch_size, num_speculative_tokens) @@ -702,11 +725,11 @@ def create_deterministic_logits(token_ids, k: int): # Mock runner for attention metadata building. 
proposer.runner = mock.MagicMock() - proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder] - proposer.runner.attn_groups[0][ - 0 - ].get_metadata_builder.return_value = attn_metadata_builder + attn_group = mock.MagicMock() + attn_group.layer_names = proposer.attn_layer_names + attn_group.metadata_builders = [attn_metadata_builder] + attn_group.get_metadata_builder.return_value = attn_metadata_builder + proposer.runner.attn_groups = [[attn_group]] proposer._get_attention_metadata_builder = mock.MagicMock( return_value=attn_metadata_builder ) @@ -729,6 +752,8 @@ def create_deterministic_logits(token_ids, k: int): block_size=16, device=device, ) + kv_cache_group_id = 0 + common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} sampling_metadata = mock.MagicMock() # Propose draft tokens. @@ -738,8 +763,8 @@ def create_deterministic_logits(token_ids, k: int): target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, last_token_indices=None, - common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, + cm_by_gid=common_attn_metadata_by_gid, ) assert result.shape == (batch_size, num_speculative_tokens) diff --git a/tests/v1/worker/test_utils.py b/tests/v1/worker/test_utils.py index 76f9a8f90f70..d861c7c3c598 100644 --- a/tests/v1/worker/test_utils.py +++ b/tests/v1/worker/test_utils.py @@ -3,6 +3,7 @@ import torch +from vllm.v1.core.kv_cache_utils import DRAFT_MODEL_PREFIX from vllm.v1.worker.utils import bind_kv_cache @@ -63,8 +64,8 @@ def test_bind_kv_cache_draft_model(default_vllm_config): layer_names = [ "model.layers.0.attn", "model.layers.1.attn", - "draft_model.layers.0.attn", - "draft_model.layers.1.attn", + f"{DRAFT_MODEL_PREFIX}.layers.0.attn", + f"{DRAFT_MODEL_PREFIX}.layers.1.attn", ] ctx = { layer_name: Attention(32, 128, 0.1, prefix=layer_name) @@ -76,17 +77,13 @@ def test_bind_kv_cache_draft_model(default_vllm_config): assert 
ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"] assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"] - assert ( - ctx["draft_model.layers.0.attn"].kv_cache[0] - is kv_cache["draft_model.layers.0.attn"] - ) - assert ( - ctx["draft_model.layers.1.attn"].kv_cache[0] - is kv_cache["draft_model.layers.1.attn"] - ) + draft_layer_0 = f"{DRAFT_MODEL_PREFIX}.layers.0.attn" + draft_layer_1 = f"{DRAFT_MODEL_PREFIX}.layers.1.attn" + assert ctx[draft_layer_0].kv_cache[0] is kv_cache[draft_layer_0] + assert ctx[draft_layer_1].kv_cache[0] is kv_cache[draft_layer_1] # caches are ordered by layer_index, interleaving target and draft model assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"] - assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"] + assert runner_kv_caches[1] is kv_cache[draft_layer_0] assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"] - assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"] + assert runner_kv_caches[3] is kv_cache[draft_layer_1] diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 13082608c47c..3d22d30df263 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -410,6 +410,12 @@ def unpadded( ) +# Mapping from KV cache group ID to its CommonAttentionMetadata. +# Each KV cache group may have different block_table_tensor and slot_mapping, +# while sharing other fields like query_start_loc and seq_lens. 
+CommonAttnMetadataByGid = dict[int, CommonAttentionMetadata] + + M = TypeVar("M") diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index dab298f1481c..dcd5c8925d02 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -828,7 +828,6 @@ def get_dcp_local_seq_lens( def extend_all_queries_by_1( common_attn_metadata: CommonAttentionMetadata, arange: torch.Tensor, - new_slot_mapping: torch.Tensor, ) -> CommonAttentionMetadata: """ Creates a new CommonAttentionMetadata with all query lengths increased by 1. @@ -852,7 +851,6 @@ def extend_all_queries_by_1( # All query lens increase by 1, so max query len increases by 1 max_query_len=cad.max_query_len + 1, max_seq_len=cad.max_seq_len + 1, - slot_mapping=new_slot_mapping, ) return new_cad diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index fd12dfe045a4..57eab29189d0 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -74,6 +74,10 @@ def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash: logger = init_logger(__name__) +# Prefix for draft model layer names in the KV cache (e.g. "draft_model.layers.0.attn"). +# Must match the prefix passed to get_model() when loading the draft model. +DRAFT_MODEL_PREFIX = "draft_model" + # The hash seed for the first block of any prefix block sequence. # # We use a random value to avoid hash collisions or PYTHONHASHSEED environment @@ -1011,12 +1015,15 @@ def _get_kv_cache_groups_uniform_page_size( Returns: The generated KVCacheGroupSpecs """ - # Group all layers by kv_cache_spec. + # Group all layers by kv_cache_spec AND whether it's a draft or target layer. # E.g., 2 full attention layers and 3 sliding window attention layers, # -> (full.0, full.1), (sw.0, sw.1, sw.2). 
- same_type_layers: dict[KVCacheSpec, list[str]] = defaultdict(list) + # Draft and target layers are kept in separate groups to prevent KV cache + # cross-contamination during speculative decoding. + same_type_layers: dict[tuple[KVCacheSpec, bool], list[str]] = defaultdict(list) for layer_name, layer_spec in kv_cache_spec.items(): - same_type_layers[layer_spec].append(layer_name) + is_draft_layer = layer_name.startswith(DRAFT_MODEL_PREFIX) + same_type_layers[(layer_spec, is_draft_layer)].append(layer_name) # Split each group into smaller groups, to make the number of layers in each # group identical. Add padding to the last group of each type if necessary. diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 18e98b267612..c7148e1eadea 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -9,10 +9,12 @@ from vllm.model_executor.layers.attention import Attention from vllm.model_executor.model_loader import get_model from vllm.triton_utils import tl, triton +from vllm.v1.attention.backend import CommonAttnMetadataByGid from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, extend_all_queries_by_1, ) +from vllm.v1.core.kv_cache_utils import DRAFT_MODEL_PREFIX from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer logger = init_logger(__name__) @@ -37,10 +39,6 @@ def __init__( self._raise_if_vocab_size_mismatch() self._raise_if_draft_tp_mismatch() - def _block_size(self) -> int: - builder = self._get_attention_metadata_builder() - return builder.kv_cache_spec.block_size - def _raise_if_multimodal(self): if self.supports_mm_inputs: raise NotImplementedError( @@ -88,9 +86,10 @@ def set_inputs_first_pass( next_token_ids: torch.Tensor, target_positions: torch.Tensor, last_token_indices: torch.Tensor | None, - cad: CommonAttentionMetadata, num_rejected_tokens_gpu: torch.Tensor | None, - ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]: + cm_by_gid: 
CommonAttnMetadataByGid, + ) -> tuple[int, torch.Tensor, CommonAttnMetadataByGid]: + cad = self.pick_first_layer_common_attn_metadata(cm_by_gid) batch_size = cad.batch_size() grid = (batch_size,) start_locs = cad.query_start_loc[:-1] @@ -126,26 +125,28 @@ def set_inputs_first_pass( rejected_tok_fill=0, ) - # recompute slot mapping - new_slot_mapping = compute_new_slot_mapping( - cad=cad, - new_positions=self.positions[:num_tokens], - is_rejected_token_mask=is_rejected_tok, - block_size=self._block_size(), - max_model_len=self.max_model_len, - ) - # update common_attn_metadata - new_cad: CommonAttentionMetadata = extend_all_queries_by_1( - cad, - arange=self.arange, - new_slot_mapping=new_slot_mapping, - ) + # Update slot_mappings across all KV cache groups + new_cm_by_gid: CommonAttnMetadataByGid = {} + for gid, cm in cm_by_gid.items(): + slot_mapping = compute_new_slot_mapping( + cad=cm, + new_positions=self.positions[:num_tokens], + is_rejected_token_mask=is_rejected_tok, + block_size=self._get_metadata_builder(gid).kv_cache_spec.block_size, + max_model_len=self.max_model_len, + block_table_tensor=cm.block_table_tensor, + ) + new_cm = cm.replace(slot_mapping=slot_mapping) + new_cm = extend_all_queries_by_1(new_cm, arange=self.arange) + new_cm_by_gid[gid] = new_cm + # Pick the updated CAM to compute new_last_token_indices + new_cad = self.pick_first_layer_common_attn_metadata(new_cm_by_gid) new_last_token_indices = new_cad.query_start_loc[1:] - 1 if num_rejected_tokens_gpu is not None: new_last_token_indices -= num_rejected_tokens_gpu - return num_tokens, new_last_token_indices, new_cad + return num_tokens, new_last_token_indices, new_cm_by_gid def load_model(self, target_model: Any) -> None: """Takes target_model to satisfy the type checker.""" @@ -167,8 +168,10 @@ def load_model(self, target_model: Any) -> None: draft_vllm_config.parallel_config.tensor_parallel_size, draft_vllm_config.parallel_config.rank, ) - with set_model_tag("draft_model"): - self.model = 
get_model(vllm_config=draft_vllm_config, prefix="draft_model") + with set_model_tag(DRAFT_MODEL_PREFIX): + self.model = get_model( + vllm_config=draft_vllm_config, prefix=DRAFT_MODEL_PREFIX + ) # This must be computed after loading the draft model # because that mutates the forward_context of the vllm_config @@ -210,8 +213,9 @@ def compute_new_slot_mapping( is_rejected_token_mask: torch.Tensor, block_size: int, max_model_len: int, + block_table_tensor: torch.Tensor, ): - batch_size, n_blocks_per_req = cad.block_table_tensor.shape + batch_size, n_blocks_per_req = block_table_tensor.shape req_indices = torch.arange(batch_size, device=cad.query_start_loc.device) req_indices = torch.repeat_interleave( req_indices, cad.naive_query_lens() + 1, output_size=len(new_positions) @@ -222,7 +226,7 @@ def compute_new_slot_mapping( block_table_indices = ( req_indices * n_blocks_per_req + clamped_positions // block_size ) - block_nums = cad.block_table_tensor.view(-1)[block_table_indices] + block_nums = block_table_tensor.view(-1)[block_table_indices] block_offsets = clamped_positions % block_size new_slot_mapping = block_nums * block_size + block_offsets # Mask out the position ids that exceed the max model length. 
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 45680a7965bb..a919f9c7a98b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast +from collections import defaultdict from dataclasses import replace +from functools import cached_property from importlib.util import find_spec from typing import cast @@ -28,8 +30,10 @@ from vllm.triton_utils import triton from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backend import ( + AttentionMetadata, AttentionMetadataBuilder, CommonAttentionMetadata, + CommonAttnMetadataByGid, ) from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.tree_attn import ( @@ -38,7 +42,6 @@ ) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher -from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -49,6 +52,7 @@ from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.utils import layer_names_to_kv_cache_group_id logger = init_logger(__name__) @@ -169,8 +173,9 @@ def __init__( with_numpy=True, ) - self._slot_mapping_buffer = torch.zeros( - self.max_num_tokens, dtype=torch.int64, device=device + # Per-gid slot-mapping buffers (lazily created via defaultdict). + self._slot_mapping_buffer_by_gid: defaultdict[int, torch.Tensor] = defaultdict( + self.new_slot_mapping_buffer ) # Determine allowed attention backends once during initialization. 
@@ -231,6 +236,21 @@ def __init__( 1, len(self.tree_choices) + 1, device=device, dtype=torch.int32 ).repeat(max_batch_size, 1) + @cached_property + def layer_names_to_kv_cache_gid(self) -> dict[str, int]: + return layer_names_to_kv_cache_group_id(self.runner.attn_groups) + + def pick_first_layer_common_attn_metadata( + self, cm_by_gid: CommonAttnMetadataByGid + ) -> CommonAttentionMetadata: + """Pick the CommonAttentionMetadata for the drafter's first attention layer. + + The drafter may have attention layers in different KV cache groups. + This method picks the CAM corresponding to the first attention layer. + """ + gid = self.layer_names_to_kv_cache_gid[self.attn_layer_names[0]] + return cm_by_gid[gid] + def _get_positions(self, num_tokens: int): if self.uses_mrope: return self.mrope_positions[:, :num_tokens] @@ -251,22 +271,49 @@ def _set_positions(self, num_tokens: int, positions: torch.Tensor): positions = positions[0] self.positions[:num_tokens] = positions + def new_slot_mapping_buffer(self) -> torch.Tensor: + """Create a new slot-mapping buffer for one KV cache group.""" + return torch.zeros( + self.max_num_tokens, + dtype=torch.int64, + device=self.hidden_states.device, + ) + def _get_slot_mapping( self, num_tokens: int, - slot_mapping: torch.Tensor | None = None, + cm_by_gid: CommonAttnMetadataByGid, ) -> dict[str, torch.Tensor]: - """Return slot_mapping dict for EAGLE layers. + """Return slot_mapping dict for EAGLE/draft layers during inference. - If slot_mapping is provided, copies it into the buffer first. + Copies per-group slot mappings from CommonAttentionMetadata into + per-gid buffers and returns views into those buffers. 
""" - if slot_mapping is not None: - num_actual = slot_mapping.shape[0] - self._slot_mapping_buffer[:num_actual].copy_(slot_mapping) - if num_tokens > num_actual: - self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID) - - view = self._slot_mapping_buffer[:num_tokens] + # COPY + for gid, cm in cm_by_gid.items(): + buf = self._slot_mapping_buffer_by_gid[gid] + src = cm.slot_mapping + n = min(num_tokens, src.shape[0]) + buf[:n].copy_(src[:n]) + if n < num_tokens: + buf[n:num_tokens].fill_(PADDING_SLOT_ID) + + # READ + layer_to_buffer: dict[str, torch.Tensor] = {} + layer_names = [*self.attn_layer_names, *self.indexer_layer_names] + for layer_name in layer_names: + gid = self.layer_names_to_kv_cache_gid[layer_name] + buf = self._slot_mapping_buffer_by_gid[gid][:num_tokens] + layer_to_buffer[layer_name] = buf + + return layer_to_buffer + + def _get_slot_mapping_dummy_run(self, num_tokens: int) -> dict[str, torch.Tensor]: + """Return slot_mapping dict for EAGLE/draft layers during dummy run. + Uses per-gid buffer (gid 0). + """ + buf = self._slot_mapping_buffer_by_gid[0] + view = buf[:num_tokens] return {name: view for name in self.attn_layer_names + self.indexer_layer_names} def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: @@ -297,16 +344,15 @@ def propose( # [batch_size] next_token_ids: torch.Tensor, last_token_indices: torch.Tensor | None, - common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, + # CommonAttentionMetadata for each KV cache group ID. 
+ cm_by_gid: CommonAttnMetadataByGid, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, ) -> torch.Tensor: - batch_size = common_attn_metadata.batch_size() - if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) target_hidden_states = self.model.combine_hidden_states( @@ -314,27 +360,35 @@ def propose( ) assert target_hidden_states.shape[-1] == self.hidden_size - num_tokens, last_token_indices, common_attn_metadata = ( - self.set_inputs_first_pass( - target_token_ids=target_token_ids, - next_token_ids=next_token_ids, - target_positions=target_positions, - last_token_indices=last_token_indices, - cad=common_attn_metadata, - num_rejected_tokens_gpu=num_rejected_tokens_gpu, - ) + assert self.runner is not None + + # Per-group block tables, block sizes, slot mappings (from common_attn_metadata) + num_tokens, last_token_indices, cm_by_gid = self.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + last_token_indices=last_token_indices, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + cm_by_gid=cm_by_gid, ) + common_attn_metadata = self.pick_first_layer_common_attn_metadata(cm_by_gid) + batch_size = common_attn_metadata.batch_size() - assert self.runner is not None + # Build attention metadata for each KV cache group + # then build it for each layer. 
+ attn_metadata_by_gid: dict[int, AttentionMetadata] = {} + for gid, cm in cm_by_gid.items(): + builder = self._get_metadata_builder(gid) + attn_metadata = builder.build_for_drafting( + common_attn_metadata=cm, draft_index=0 + ) + attn_metadata_by_gid[gid] = attn_metadata - if self.attn_metadata_builder is None: - attn_metadata_builder = self._get_attention_metadata_builder() - else: - attn_metadata_builder = self.attn_metadata_builder + per_layer_attn_metadata: dict[str, AttentionMetadata] = {} + for layer_name in self.attn_layer_names: + gid = self.layer_names_to_kv_cache_gid[layer_name] + per_layer_attn_metadata[layer_name] = attn_metadata_by_gid[gid] - attn_metadata = attn_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, draft_index=0 - ) # FIXME: support hybrid kv for draft model (remove separate indexer) if self.draft_indexer_metadata_builder: draft_indexer_metadata = ( @@ -345,11 +399,6 @@ def propose( ) else: draft_indexer_metadata = None - # At this moment, we assume all eagle layers belong to the same KV - # cache group, thus using the same attention metadata. 
- per_layer_attn_metadata = {} - for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata for layer_name in self.indexer_layer_names: assert draft_indexer_metadata is not None @@ -401,7 +450,8 @@ def propose( num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, slot_mapping=self._get_slot_mapping( - num_input_tokens, common_attn_metadata.slot_mapping + num_tokens=num_input_tokens, + cm_by_gid=cm_by_gid, ), ): ret_hidden_states = self.model(**model_kwargs) @@ -442,6 +492,7 @@ def propose( hidden_states=hidden_states, common_attn_metadata=common_attn_metadata, slot_mappings=slot_mappings, + cm_by_gid=cm_by_gid, ) # [batch_size, num_tree_tokens] return torch.cat(draft_token_ids_list, dim=1) @@ -538,38 +589,55 @@ def propose( if common_attn_metadata._num_computed_tokens_cpu is not None: common_attn_metadata._num_computed_tokens_cpu += 1 - # Compute the slot mapping. - block_size = attn_metadata_builder.kv_cache_spec.block_size - if self.uses_mrope: - # all dimensions of positions are the same - block_numbers = clamped_positions[0] // block_size - else: - block_numbers = clamped_positions // block_size - block_ids = common_attn_metadata.block_table_tensor.gather( - dim=1, index=block_numbers.view(-1, 1) - ) - block_ids = block_ids.view(-1) - if self.uses_mrope: - common_attn_metadata.slot_mapping = ( - block_ids * block_size + clamped_positions[0] % block_size + # Update all other CommonAttentionMetadata objects with the updated + # information computed above. 
+ for gid, cm in cm_by_gid.items(): + cm_by_gid[gid] = common_attn_metadata.replace( + slot_mapping=cm.slot_mapping, + block_table_tensor=cm.block_table_tensor, ) - else: - common_attn_metadata.slot_mapping = ( - block_ids * block_size + clamped_positions % block_size + + # Compute the slot mapping for each kv-cache group + attn_metadata_by_gid = {} + for gid, cm in cm_by_gid.items(): + blk_table_tensor = cm.block_table_tensor + + # Compute per-group block_size and block_numbers + builder = self._get_metadata_builder(gid) + group_block_size = builder.kv_cache_spec.block_size + if self.uses_mrope: + group_block_numbers = clamped_positions[0] // group_block_size + else: + group_block_numbers = clamped_positions // group_block_size + + block_ids = blk_table_tensor.gather( + dim=1, index=group_block_numbers.view(-1, 1) ) - # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadvertently updated with the - # padding tokens. - common_attn_metadata.slot_mapping.masked_fill_( - exceeds_max_model_len, PADDING_SLOT_ID - ) + block_ids = block_ids.view(-1) + if self.uses_mrope: + slot_mapping = ( + block_ids * group_block_size + + clamped_positions[0] % group_block_size + ) + else: + slot_mapping = ( + block_ids * group_block_size + + clamped_positions % group_block_size + ) + slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + + new_cm = cm_by_gid[gid].replace(slot_mapping=slot_mapping) + attn_metadata = builder.build_for_drafting( + common_attn_metadata=new_cm, + draft_index=token_index + 1, + ) + cm_by_gid[gid] = new_cm + attn_metadata_by_gid[gid] = attn_metadata # Rebuild attention metadata - attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore - common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 - ) for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata + gid = self.layer_names_to_kv_cache_gid[layer_name] + 
per_layer_attn_metadata[layer_name] = attn_metadata_by_gid[gid] # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids @@ -600,7 +668,8 @@ def propose( num_tokens_across_dp=batch_size_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, slot_mapping=self._get_slot_mapping( - input_batch_size, common_attn_metadata.slot_mapping + num_tokens=input_batch_size, + cm_by_gid=cm_by_gid, ), ): ret_hidden_states = self.model(**model_kwargs) @@ -625,9 +694,10 @@ def set_inputs_first_pass( next_token_ids: torch.Tensor, target_positions: torch.Tensor, last_token_indices: torch.Tensor | None, - cad: CommonAttentionMetadata, num_rejected_tokens_gpu: torch.Tensor | None, - ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]: + cm_by_gid: CommonAttnMetadataByGid, + ) -> tuple[int, torch.Tensor, CommonAttnMetadataByGid]: + cad = self.pick_first_layer_common_attn_metadata(cm_by_gid) if last_token_indices is None: last_token_indices = cad.query_start_loc[1:] - 1 @@ -644,7 +714,7 @@ def set_inputs_first_pass( target_positions = target_positions[0] self._set_positions(num_tokens, target_positions) - return num_tokens, last_token_indices, cad + return num_tokens, last_token_indices, cm_by_gid def model_returns_tuple(self) -> bool: return self.method not in ("mtp", "draft_model") @@ -684,7 +754,7 @@ def prepare_next_token_ids_cpu( def prepare_next_token_ids_padded( self, - common_attn_metadata: CommonAttentionMetadata, + cm_by_gid: CommonAttnMetadataByGid, sampled_token_ids: torch.Tensor, requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, @@ -697,6 +767,7 @@ def prepare_next_token_ids_padded( is not sampled and comes from `request.get_token_id()` instead. This is denoted the "backup" token id. It also counts rejected tokens via `sampled_token_ids`. 
""" + common_attn_metadata = self.pick_first_layer_common_attn_metadata(cm_by_gid) # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs self.backup_next_token_ids.np[:num_reqs] = np.array( @@ -742,10 +813,10 @@ def prepare_next_token_ids_padded( def prepare_inputs_padded( self, - common_attn_metadata: CommonAttentionMetadata, + cm_by_gid: CommonAttnMetadataByGid, spec_decode_metadata: SpecDecodeMetadata, valid_sampled_tokens_count: torch.Tensor, - ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + ) -> tuple[CommonAttnMetadataByGid, torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding It updates the common_attn_metadata for speculative decoding, @@ -754,6 +825,9 @@ def prepare_inputs_padded( used as padding and filtered out later by `token_indices_to_sample`. No blocking CPU operations should be introduced in this function. """ + # Pick the CAM for the drafter's first attention layer as the base + common_attn_metadata = self.pick_first_layer_common_attn_metadata(cm_by_gid) + num_reqs = common_attn_metadata.num_reqs device = valid_sampled_tokens_count.device @@ -779,7 +853,8 @@ def prepare_inputs_padded( total_num_tokens = query_start_loc_cpu[-1].item() - spec_common_attn_metadata = CommonAttentionMetadata( + # First create the new base CAM with transformed shared fields + new_cam = CommonAttentionMetadata( query_start_loc=common_attn_metadata.query_start_loc, seq_lens=common_attn_metadata.seq_lens, query_start_loc_cpu=query_start_loc_cpu, @@ -789,14 +864,24 @@ def prepare_inputs_padded( num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens], + # block-table and slot-mapping are set below + block_table_tensor=torch.empty(0), + 
slot_mapping=torch.empty(0), causal=True, dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) + # Then create new CAMs for all groups, preserving each group's + # block_table_tensor and slot_mapping + new_cm_by_gid: CommonAttnMetadataByGid = {} + for gid, cm in cm_by_gid.items(): + new_cm_by_gid[gid] = new_cam.replace( + block_table_tensor=cm.block_table_tensor, + slot_mapping=cm.slot_mapping[:total_num_tokens], + ) + return ( - spec_common_attn_metadata, + new_cm_by_gid, token_indices_to_sample, num_rejected_tokens_gpu, ) @@ -811,6 +896,7 @@ def propose_tree( # [num_tokens, hidden_size] hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, + cm_by_gid: CommonAttnMetadataByGid, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, @@ -936,7 +1022,8 @@ def propose_tree( num_tokens=num_input_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, slot_mapping=self._get_slot_mapping( - num_input_tokens, attn_metadata.slot_mapping + num_tokens=num_input_tokens, + cm_by_gid=cm_by_gid, ), ): last_hidden_states, hidden_states = self.model( @@ -976,16 +1063,19 @@ def propose_tree( def prepare_inputs( self, - common_attn_metadata: CommonAttentionMetadata, + cm_by_gid: CommonAttnMetadataByGid, sampled_token_ids: list[list[int]], num_draft_tokens: list[int], - ) -> tuple[CommonAttentionMetadata, torch.Tensor]: + ) -> tuple[CommonAttnMetadataByGid, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding. It updates to the common_attn_metadata to account for the rejected tokens (and newly sampled tokens). It also returns the token indices of the tokens that should be fed to the speculator. """ + # Pick the CAM for the drafter's first attention layer as the base + common_attn_metadata = self.pick_first_layer_common_attn_metadata(cm_by_gid) + # E.g. 
# common_attn_metadata.query_start_loc{_cpu}: # [0, q1, q1 + q2, q1 + q2 + q3] @@ -1058,7 +1148,8 @@ def prepare_inputs( token_indices_np = token_offsets + old_query_start_locs_expanded token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True) - spec_common_attn_metadata = CommonAttentionMetadata( + # First create the new base CAM with transformed shared fields + new_cam = CommonAttentionMetadata( query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), query_start_loc_cpu=new_query_start_loc_cpu, @@ -1068,13 +1159,23 @@ def prepare_inputs( num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), max_seq_len=new_seq_lens_cpu.max().item(), - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[token_indices], + # block-table and slot-mapping are set below + block_table_tensor=torch.empty(0), + slot_mapping=torch.empty(0), causal=True, dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) - return spec_common_attn_metadata, token_indices + # Then create new CAMs for all groups, preserving each group's + # block_table_tensor and slot_mapping + new_cm_by_gid: CommonAttnMetadataByGid = {} + for gid, cm in cm_by_gid.items(): + new_cm_by_gid[gid] = new_cam.replace( + block_table_tensor=cm.block_table_tensor, + slot_mapping=cm.slot_mapping[token_indices], + ) + + return new_cm_by_gid, token_indices def get_model_name(self, model: nn.Module) -> str: if hasattr(model, "module"): # multi-GPU @@ -1311,7 +1412,9 @@ def dummy_run( and slot_mappings is not None and self.attn_layer_names[0] in slot_mappings ): - slot_mapping_dict = self._get_slot_mapping(num_input_tokens) + slot_mapping_dict = self._get_slot_mapping_dummy_run( + num_tokens=num_input_tokens + ) else: slot_mapping_dict = slot_mappings or {} @@ -1339,30 +1442,8 @@ def dummy_run( kwargs["hidden_states"] = 
self.hidden_states[:num_input_tokens] self.model(**kwargs) - def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: - """Find and return the attention metadata builders for EAGLE layers. - - Returns: - The metadata builders for EAGLE layers. - - Raises: - AssertionError: If no metadata builders are found for EAGLE layers. - """ - builder = None - chosen_layer = self.attn_layer_names[0] - - for kv_cache_group in self.runner.attn_groups: - for attn_group in kv_cache_group: - if chosen_layer in attn_group.layer_names: - builder = attn_group.get_metadata_builder() - break - if builder is not None: - break - - assert builder is not None, ( - "Failed to find attention metadata builder for EAGLE layers." - ) - return builder + def _get_metadata_builder(self, kv_cache_group_id: int) -> AttentionMetadataBuilder: + return self.runner.attn_groups[kv_cache_group_id][0].get_metadata_builder() def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool: """ @@ -1380,29 +1461,6 @@ def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool: use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True) return use_aux_hidden_state - def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: - """ - Validate that all drafting layers belong to the same KVCacheGroup. - Need this assumption to ensure all drafting layers can use the - same AttentionMetadata. - May extend to multiple AttentionMetadata in the future. 
- """ - kv_cache_groups: dict[str, int] = {} - for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): - for layer_name in kv_cache_group.layer_names: - kv_cache_groups[layer_name] = id - assert ( - len( - set( - [ - kv_cache_groups[layer_name] - for layer_name in self.attn_layer_names - ] - ) - ) - == 1 - ), "All drafting layers should belong to the same kv cache group" - def _pad_batch_across_dp( self, num_tokens_unpadded: int, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 39ac6bce820e..c78b157ba123 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -110,6 +110,7 @@ AttentionMetadataBuilder, AttentionType, CommonAttentionMetadata, + CommonAttnMetadataByGid, MultipleOf, ) from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder @@ -316,13 +317,13 @@ class ExecuteModelState(NamedTuple): scheduler_output: "SchedulerOutput" logits: torch.Tensor spec_decode_metadata: SpecDecodeMetadata | None - spec_decode_common_attn_metadata: CommonAttentionMetadata | None hidden_states: torch.Tensor sample_hidden_states: torch.Tensor aux_hidden_states: list[torch.Tensor] | None ec_connector_output: ECConnectorOutput | None cudagraph_stats: CUDAGraphStat | None slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None + cm_by_gid: CommonAttnMetadataByGid class GPUModelRunner( @@ -1680,13 +1681,10 @@ def _build_attention_metadata( num_scheduled_tokens: dict[str, int] | None = None, cascade_attn_prefix_lens: list[list[int]] | None = None, slot_mappings: dict[int, torch.Tensor] | None = None, - ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: - """ - :return: tuple[attn_metadata, spec_decode_common_attn_metadata] - """ + ) -> tuple[PerLayerAttnMetadata, CommonAttnMetadataByGid]: # Attention metadata is not needed for attention free models if len(self.kv_cache_config.kv_cache_groups) == 0: - return {}, None + return {}, dict() 
num_tokens_padded = num_tokens_padded or num_tokens num_reqs_padded = num_reqs_padded or num_reqs @@ -1784,6 +1782,10 @@ def _get_block_table(kv_cache_gid: int): tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata ] = {} + # CommonAttentionMetadata objects for each KV cache group. + # These are passed to the drafter to support multi-group KV cache in drafters. + cm_by_gid: CommonAttnMetadataByGid = {} + def _build_attn_group_metadata( kv_cache_gid: int, attn_gid: int, @@ -1847,7 +1849,6 @@ def _build_attn_group_metadata( # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. - spec_decode_common_attn_metadata = None for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups): cm = copy(cm_base) # shallow copy @@ -1859,16 +1860,10 @@ def _build_attn_group_metadata( num_reqs_padded, for_cudagraph_capture=for_cudagraph_capture, ) - if kv_cache_gid > 0: - cm.block_table_tensor = _get_block_table(kv_cache_gid) - cm.slot_mapping = slot_mappings[kv_cache_gid] - - if self.speculative_config and spec_decode_common_attn_metadata is None: - if isinstance(self.drafter, EagleProposer): - if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: - spec_decode_common_attn_metadata = cm - else: - spec_decode_common_attn_metadata = cm + cm.block_table_tensor = _get_block_table(kv_cache_gid) + cm.slot_mapping = slot_mappings[kv_cache_gid] + + cm_by_gid[kv_cache_gid] = cm for attn_gid in range(len(self.attn_groups[kv_cache_gid])): if ubatch_slices is not None: @@ -1898,17 +1893,17 @@ def _build_attn_group_metadata( for _metadata in attn_metadata.values(): _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] - if spec_decode_common_attn_metadata is not None and ( + if cm_by_gid and ( num_reqs != num_reqs_padded or num_tokens != num_tokens_padded ): # Currently the drafter still only uses piecewise cudagraphs (and modifies # the attention metadata in directly), and therefore 
does not want to use # padded attention metadata. - spec_decode_common_attn_metadata = ( - spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) - ) + cm_by_gid = { + gid: cm.unpadded(num_tokens, num_reqs) for gid, cm in cm_by_gid.items() + } - return attn_metadata, spec_decode_common_attn_metadata + return attn_metadata, cm_by_gid def _compute_cascade_attn_prefix_lens( self, @@ -3113,8 +3108,8 @@ def _determine_batch_execution_and_padding( has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) - dispatch_cudagraph = ( - lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch( + dispatch_cudagraph = lambda num_tokens, disable_full: ( + self.cudagraph_dispatcher.dispatch( num_tokens=num_tokens, has_lora=has_lora, uniform_decode=uniform_decode, @@ -3472,20 +3467,18 @@ def execute_model( ubatch_slices=ubatch_slices_padded, ) - attn_metadata, spec_decode_common_attn_metadata = ( - self._build_attention_metadata( - num_tokens=num_tokens_unpadded, - num_tokens_padded=num_tokens_padded if pad_attn else None, - num_reqs=num_reqs, - num_reqs_padded=num_reqs_padded if pad_attn else None, - max_query_len=max_num_scheduled_tokens, - ubatch_slices=ubatch_slices_attn, - logits_indices=logits_indices, - use_spec_decode=use_spec_decode, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - cascade_attn_prefix_lens=cascade_attn_prefix_lens, - slot_mappings=slot_mappings_by_group, - ) + attn_metadata, cm_by_gid = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + 
cascade_attn_prefix_lens=cascade_attn_prefix_lens, + slot_mappings=slot_mappings_by_group, ) ( @@ -3602,13 +3595,13 @@ def execute_model( scheduler_output, logits, spec_decode_metadata, - spec_decode_common_attn_metadata, hidden_states, sample_hidden_states, aux_hidden_states, ec_connector_output, cudagraph_stats, slot_mappings, + cm_by_gid, ) self.kv_connector_output = kv_connector_output return None @@ -3641,13 +3634,13 @@ def sample_tokens( scheduler_output, logits, spec_decode_metadata, - spec_decode_common_attn_metadata, hidden_states, sample_hidden_states, aux_hidden_states, ec_connector_output, cudagraph_stats, slot_mappings, + cm_by_gid, ) = self.execute_model_state # Clear ephemeral state. self.execute_model_state = None @@ -3679,7 +3672,7 @@ def sample_tokens( self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): - assert spec_decode_common_attn_metadata is not None + assert cm_by_gid is not None with record_function_or_nullcontext("gpu_model_runner: draft"): self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, @@ -3689,16 +3682,19 @@ def propose_draft_token_ids(sampled_token_ids): sample_hidden_states, aux_hidden_states, spec_decode_metadata, - spec_decode_common_attn_metadata, slot_mappings, + cm_by_gid, ) self._copy_draft_token_ids_to_cpu(scheduler_output) spec_config = self.speculative_config propose_drafts_after_bookkeeping = False if spec_config is not None: - input_fits_in_drafter = spec_decode_common_attn_metadata is not None and ( - spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens + # Use any CAM from cm_by_gid for the max_seq_len check (it's a global + # batch property shared across all KV cache groups). 
+ any_cm = next(iter(cm_by_gid.values()), None) if cm_by_gid else None + input_fits_in_drafter = any_cm is not None and ( + any_cm.max_seq_len + self.num_spec_tokens <= self.effective_drafter_max_model_len ) use_gpu_toks = ( @@ -3712,10 +3708,9 @@ def propose_draft_token_ids(sampled_token_ids): if input_fits_in_drafter: propose_draft_token_ids(sampled_token_ids) elif self.valid_sampled_token_count_event is not None: - assert spec_decode_common_attn_metadata is not None next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - spec_decode_common_attn_metadata, + cm_by_gid, sampled_token_ids, self.requests, self.input_batch, @@ -3937,8 +3932,8 @@ def propose_draft_token_ids( sample_hidden_states: torch.Tensor, aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, - common_attn_metadata: CommonAttentionMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, + cm_by_gid: CommonAttnMetadataByGid, ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config @@ -4014,7 +4009,7 @@ def propose_draft_token_ids( ) next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - common_attn_metadata, + cm_by_gid, sampled_token_ids, self.requests, self.input_batch, @@ -4041,8 +4036,8 @@ def propose_draft_token_ids( else: if spec_config.disable_padded_drafter_batch: token_indices_to_sample = None - common_attn_metadata, token_indices = self.drafter.prepare_inputs( - common_attn_metadata, + cm_by_gid, token_indices = self.drafter.prepare_inputs( + cm_by_gid, sampled_token_ids, spec_decode_metadata.num_draft_tokens, ) @@ -4057,15 +4052,18 @@ def propose_draft_token_ids( target_hidden_states = hidden_states[token_indices] else: ( - common_attn_metadata, + cm_by_gid, token_indices_to_sample, num_rejected_tokens_gpu, ) = self.drafter.prepare_inputs_padded( - 
common_attn_metadata, + cm_by_gid, spec_decode_metadata, valid_sampled_tokens_count, ) - total_num_tokens = common_attn_metadata.num_actual_tokens + + # Get any CAM to access the shared num_actual_tokens + any_cm = next(iter(cm_by_gid.values())) + total_num_tokens = any_cm.num_actual_tokens # When padding the batch, token_indices is just a range target_token_ids = self.input_ids.gpu[:total_num_tokens] target_positions = self._get_positions(total_num_tokens) @@ -4092,10 +4090,10 @@ def propose_draft_token_ids( next_token_ids=next_token_ids, last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, - common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs, num_rejected_tokens_gpu=num_rejected_tokens_gpu, slot_mappings=slot_mappings, + cm_by_gid=cm_by_gid, ) return draft_token_ids @@ -6055,15 +6053,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config, kernel_block_sizes ) - if self.speculative_config and ( - self.speculative_config.use_eagle() - or self.speculative_config.uses_draft_model() - ): - assert isinstance(self.drafter, EagleProposer | DraftModelProposer) - # validate all draft model layers belong to the same kv cache - # group - self.drafter.validate_same_kv_cache_group(kv_cache_config) - if has_kv_transfer_group(): kv_transfer_group = get_kv_transfer_group() if self.cross_layers_kv_cache is not None: diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index f13c75a7ae78..5b7f410dbab3 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -237,3 +237,16 @@ def is_residual_scattered_for_sp( if compile_sizes is None: return False return num_input_tokens in compile_sizes + + +def layer_names_to_kv_cache_group_id( + attn_groups: list[list[AttentionGroup]], + only_prefix: str = "", # Note that str.startswith("") is True +) -> dict[str, int]: + layer_to_kv_cache_gid = {} + for group_id, listof_attn_groups in enumerate(attn_groups): + for attn_group in 
listof_attn_groups: + for layer_name in attn_group.layer_names: + if layer_name.startswith(only_prefix): + layer_to_kv_cache_gid[layer_name] = group_id + return layer_to_kv_cache_gid From aa6d082ba4ecb73c3a26fbcffb801cbf94155d7c Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 5 Feb 2026 14:35:27 +0100 Subject: [PATCH 2/8] Factor out get_slot_mappings_by_layer() and remove layer_names_to_kv_cache_group_id() Signed-off-by: Tomas Ruiz --- vllm/v1/core/kv_cache_utils.py | 53 +++++++++++++++++++++++++++++- vllm/v1/spec_decode/eagle.py | 4 +-- vllm/v1/worker/gpu_model_runner.py | 24 ++++---------- vllm/v1/worker/utils.py | 13 -------- 4 files changed, 61 insertions(+), 33 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 57eab29189d0..8807750dedd3 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -7,7 +7,12 @@ from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass, replace -from typing import Any, NewType, TypeAlias, overload +from typing import TYPE_CHECKING, Any, NewType, TypeAlias, overload + +import torch + +if TYPE_CHECKING: + from vllm.v1.worker.ubatch_utils import UBatchSlices from vllm import envs from vllm.config import VllmConfig @@ -1649,3 +1654,49 @@ def _get_value_at(self, idx: int) -> BlockHash: BlockHashList = list[BlockHash] | BlockHashListWithBlockSize + + +def get_slot_mappings_by_layer( + kv_cache_config: KVCacheConfig, + slot_mappings_by_gid: dict[int, torch.Tensor], + ubatch_slices: "UBatchSlices | None" = None, +) -> dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]: + """ + Convert slot mappings from group ID indexing to layer name indexing. + + Args: + kv_cache_config: KV cache configuration containing group-to-layer mappings. + slot_mappings_by_gid: Slot mappings keyed by KV cache group ID. 
+        ubatch_slices: Optional ubatch slicing info for DBO (Dual Batch
+            Overlap). When provided, returns a list of sliced
+            mappings per ubatch.
+
+    Returns:
+        dict[str, torch.Tensor]: Slot mappings keyed by layer name for
+            ForwardContext, or list[dict[str, torch.Tensor]] when ubatch_slices
+            is provided.
+    """
+    slot_mappings_by_layer: dict[str, torch.Tensor] = {
+        layer_name: slot_mappings_by_gid[gid]
+        for layer_name, gid in kv_cache_group_id_by_layer(kv_cache_config).items()
+    }
+
+    if ubatch_slices is not None:
+        result: list[dict[str, torch.Tensor]] = []
+        for ubatch in ubatch_slices:
+            sliced_mappings: dict[str, torch.Tensor] = {}
+            for layer_name, slot_mapping in slot_mappings_by_layer.items():
+                sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice]
+            result.append(sliced_mappings)
+        return result
+
+    return slot_mappings_by_layer
+
+
+def kv_cache_group_id_by_layer(kv_cache_config: KVCacheConfig) -> dict[str, int]:
+    """Return a mapping from layer_name -> KV cache group ID."""
+    gid_by_layer: dict[str, int] = {}
+    for gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+        for layer_name in kv_cache_group.layer_names:
+            gid_by_layer[layer_name] = gid
+    return gid_by_layer
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index a919f9c7a98b..861f7d9817e0 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -41,6 +41,7 @@
     TreeAttentionMetadataBuilder,
 )
 from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
+from vllm.v1.core.kv_cache_utils import kv_cache_group_id_by_layer
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import _SAMPLING_EPS
@@ -52,7 +53,6 @@
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
-from vllm.v1.worker.utils import 
layer_names_to_kv_cache_group_id logger = init_logger(__name__) @@ -238,7 +238,7 @@ def __init__( @cached_property def layer_names_to_kv_cache_gid(self) -> dict[str, int]: - return layer_names_to_kv_cache_group_id(self.runner.attn_groups) + return kv_cache_group_id_by_layer(self.runner.kv_cache_config) def pick_first_layer_common_attn_metadata( self, cm_by_gid: CommonAttnMetadataByGid diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c78b157ba123..87d4333121d5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -119,6 +119,7 @@ get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, ) +from vllm.v1.core.kv_cache_utils import get_slot_mappings_by_layer from vllm.v1.core.sched.output import NewRequestData from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import ( @@ -3247,7 +3248,9 @@ def _get_slot_mappings( Returns: A tuple of: - slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata - - slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext + - slot_mappings_by_layer: dict[str, torch.Tensor] or list for + ForwardContext + Or None if KV cache config is not available. 
""" if not ( hasattr(self, "kv_cache_config") @@ -3281,22 +3284,9 @@ def _get_slot_mapping(kv_cache_gid: int): gid: _get_slot_mapping(gid) for gid, _ in enumerate(self.kv_cache_config.kv_cache_groups) } - - slot_mappings_by_layer: dict[str, torch.Tensor] = {} - for gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups): - slot_mapping = slot_mappings_by_gid[gid] - for layer_name in kv_cache_group.layer_names: - slot_mappings_by_layer[layer_name] = slot_mapping - - if ubatch_slices is not None: - result: list[dict[str, torch.Tensor]] = [] - for ubatch in ubatch_slices: - sliced_mappings: dict[str, torch.Tensor] = {} - for layer_name, slot_mapping in slot_mappings_by_layer.items(): - sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice] - result.append(sliced_mappings) - return slot_mappings_by_gid, result - + slot_mappings_by_layer = get_slot_mappings_by_layer( + self.kv_cache_config, slot_mappings_by_gid, ubatch_slices + ) return slot_mappings_by_gid, slot_mappings_by_layer @torch.inference_mode() diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 5b7f410dbab3..f13c75a7ae78 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -237,16 +237,3 @@ def is_residual_scattered_for_sp( if compile_sizes is None: return False return num_input_tokens in compile_sizes - - -def layer_names_to_kv_cache_group_id( - attn_groups: list[list[AttentionGroup]], - only_prefix: str = "", # Note that str.startswith("") is True -) -> dict[str, int]: - layer_to_kv_cache_gid = {} - for group_id, listof_attn_groups in enumerate(attn_groups): - for attn_group in listof_attn_groups: - for layer_name in attn_group.layer_names: - if layer_name.startswith(only_prefix): - layer_to_kv_cache_gid[layer_name] = group_id - return layer_to_kv_cache_gid From 0f724bb4c91f58c51d06e6e4be562c528e949f58 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 5 Feb 2026 15:14:34 +0100 Subject: [PATCH 3/8] Reuse get_slot_mappings_by_layer() in EAGLE-3 
Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 60 +++++++++++++++++------------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 861f7d9817e0..e2d30ff69c49 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast from collections import defaultdict +from collections.abc import Iterator from dataclasses import replace from functools import cached_property from importlib.util import find_spec @@ -41,7 +42,10 @@ TreeAttentionMetadataBuilder, ) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata -from vllm.v1.core.kv_cache_utils import kv_cache_group_id_by_layer +from vllm.v1.core.kv_cache_utils import ( + get_slot_mappings_by_layer, + kv_cache_group_id_by_layer, +) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS @@ -280,33 +284,36 @@ def new_slot_mapping_buffer(self) -> torch.Tensor: ) def _get_slot_mapping( - self, - num_tokens: int, - cm_by_gid: CommonAttnMetadataByGid, + self, cm_by_gid: CommonAttnMetadataByGid ) -> dict[str, torch.Tensor]: """Return slot_mapping dict for EAGLE/draft layers during inference. Copies per-group slot mappings from CommonAttentionMetadata into per-gid buffers and returns views into those buffers. """ - # COPY + slot_mapping_by_gid = dict(self._set_and_get_slot_mapping_bufs(cm_by_gid)) + return get_slot_mappings_by_layer( + self.runner.kv_cache_config, + slot_mapping_by_gid, + ubatch_slices=None, + ) + + def _set_and_get_slot_mapping_bufs( + self, cm_by_gid: CommonAttnMetadataByGid + ) -> Iterator[tuple[int, torch.Tensor]]: + """ + Set the slot mapping buffers for each KV cache group. + Then yield the buffer for each KV cache group ID. 
+ """ for gid, cm in cm_by_gid.items(): - buf = self._slot_mapping_buffer_by_gid[gid] src = cm.slot_mapping - n = min(num_tokens, src.shape[0]) + src_len = src.shape[0] + buf = self._slot_mapping_buffer_by_gid[gid] + n = min(cm.num_actual_tokens, src_len) buf[:n].copy_(src[:n]) - if n < num_tokens: - buf[n:num_tokens].fill_(PADDING_SLOT_ID) - - # READ - layer_to_buffer: dict[str, torch.Tensor] = {} - layer_names = [*self.attn_layer_names, *self.indexer_layer_names] - for layer_name in layer_names: - gid = self.layer_names_to_kv_cache_gid[layer_name] - buf = self._slot_mapping_buffer_by_gid[gid][:num_tokens] - layer_to_buffer[layer_name] = buf - - return layer_to_buffer + if n < src_len: + buf[n:src_len].fill_(PADDING_SLOT_ID) + yield gid, buf[:src_len] def _get_slot_mapping_dummy_run(self, num_tokens: int) -> dict[str, torch.Tensor]: """Return slot_mapping dict for EAGLE/draft layers during dummy run. @@ -449,10 +456,7 @@ def propose( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, - slot_mapping=self._get_slot_mapping( - num_tokens=num_input_tokens, - cm_by_gid=cm_by_gid, - ), + slot_mapping=self._get_slot_mapping(cm_by_gid=cm_by_gid), ): ret_hidden_states = self.model(**model_kwargs) if not self.model_returns_tuple(): @@ -667,10 +671,7 @@ def propose( num_tokens=input_batch_size, num_tokens_across_dp=batch_size_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, - slot_mapping=self._get_slot_mapping( - num_tokens=input_batch_size, - cm_by_gid=cm_by_gid, - ), + slot_mapping=self._get_slot_mapping(cm_by_gid), ): ret_hidden_states = self.model(**model_kwargs) if not self.model_returns_tuple(): @@ -1021,10 +1022,7 @@ def propose_tree( self.vllm_config, num_tokens=num_input_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, - slot_mapping=self._get_slot_mapping( - num_tokens=num_input_tokens, - cm_by_gid=cm_by_gid, - ), + slot_mapping=self._get_slot_mapping(cm_by_gid), ): 
last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], From 1caf68b32baf6bf271271c1bbbb9c4b7d51e20f1 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 5 Feb 2026 15:41:19 +0100 Subject: [PATCH 4/8] Reduce code changes Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index c7148e1eadea..b927e80660d0 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -134,7 +134,6 @@ def set_inputs_first_pass( is_rejected_token_mask=is_rejected_tok, block_size=self._get_metadata_builder(gid).kv_cache_spec.block_size, max_model_len=self.max_model_len, - block_table_tensor=cm.block_table_tensor, ) new_cm = cm.replace(slot_mapping=slot_mapping) new_cm = extend_all_queries_by_1(new_cm, arange=self.arange) @@ -213,9 +212,8 @@ def compute_new_slot_mapping( is_rejected_token_mask: torch.Tensor, block_size: int, max_model_len: int, - block_table_tensor: torch.Tensor, ): - batch_size, n_blocks_per_req = block_table_tensor.shape + batch_size, n_blocks_per_req = cad.block_table_tensor.shape req_indices = torch.arange(batch_size, device=cad.query_start_loc.device) req_indices = torch.repeat_interleave( req_indices, cad.naive_query_lens() + 1, output_size=len(new_positions) @@ -226,7 +224,7 @@ def compute_new_slot_mapping( block_table_indices = ( req_indices * n_blocks_per_req + clamped_positions // block_size ) - block_nums = block_table_tensor.view(-1)[block_table_indices] + block_nums = cad.block_table_tensor.view(-1)[block_table_indices] block_offsets = clamped_positions % block_size new_slot_mapping = block_nums * block_size + block_offsets # Mask out the position ids that exceed the max model length. 
From 523b54d4d6988f55ed27f1b1ddf6a76a95123021 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 5 Feb 2026 15:44:07 +0100 Subject: [PATCH 5/8] Fix mypy issue Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index e2d30ff69c49..8cf8e60c5d59 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -285,7 +285,7 @@ def new_slot_mapping_buffer(self) -> torch.Tensor: def _get_slot_mapping( self, cm_by_gid: CommonAttnMetadataByGid - ) -> dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]: """Return slot_mapping dict for EAGLE/draft layers during inference. Copies per-group slot mappings from CommonAttentionMetadata into From 87cc770f8e0c4e06409e9b8a814e7f78b73b6878 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 5 Feb 2026 16:08:07 +0100 Subject: [PATCH 6/8] Fix test_propose() Signed-off-by: Tomas Ruiz --- tests/v1/spec_decode/test_eagle.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 52f7be518240..e992343beb99 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -596,6 +596,7 @@ def create_deterministic_logits(token_ids): ) kv_cache_group_id = 0 common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} + proposer.layer_names_to_kv_cache_gid = {"layer.0": kv_cache_group_id} result = proposer.propose( target_token_ids=target_token_ids, @@ -733,6 +734,8 @@ def create_deterministic_logits(token_ids, k: int): proposer._get_attention_metadata_builder = mock.MagicMock( return_value=attn_metadata_builder ) + kv_cache_group_id = 0 + proposer.layer_names_to_kv_cache_gid = {"layer.0": kv_cache_group_id} # Setup inputs for the proposer. 
target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) @@ -752,7 +755,6 @@ def create_deterministic_logits(token_ids, k: int): block_size=16, device=device, ) - kv_cache_group_id = 0 common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} sampling_metadata = mock.MagicMock() From d4f8b6783aeef4bd2d46c10c8d8672f38b841175 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 5 Feb 2026 16:25:41 +0100 Subject: [PATCH 7/8] Fix for docstrings Signed-off-by: Tomas Ruiz --- vllm/v1/worker/gpu_model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 87d4333121d5..01eebf5957d2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3248,8 +3248,7 @@ def _get_slot_mappings( Returns: A tuple of: - slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata - - slot_mappings_by_layer: dict[str, torch.Tensor] or list for - ForwardContext + - slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext Or None if KV cache config is not available. """ if not ( From 9e3be0e454b180f148aedc0980fef658d23ab915 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 25 Feb 2026 12:40:56 +0100 Subject: [PATCH 8/8] Fix test_set_inputs_first_pass tests for cm_by_gid API Adapt the three test_set_inputs_first_pass_* tests (from main) to use the multi-gid API: wrap CommonAttentionMetadata in cm_by_gid dict, set up layer_names_to_kv_cache_gid, and mock _get_metadata_builder for per-group slot mapping in the draft model/parallel drafting paths. 
Co-Authored-By: Claude Opus 4.6 Signed-off-by: Tomas Ruiz --- tests/v1/spec_decode/test_eagle.py | 82 +++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 4cd59de020de..5217fc3e19b1 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -383,6 +383,9 @@ def test_set_inputs_first_pass_default_eagle(): num_speculative_tokens = 3 proposer = _create_proposer("eagle", num_speculative_tokens) + kv_cache_group_id = 0 + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} # Setup batch with 3 requests batch_spec = BatchSpec( @@ -395,6 +398,7 @@ def test_set_inputs_first_pass_default_eagle(): block_size=16, device=device, ) + cm_by_gid = {kv_cache_group_id: common_attn_metadata} # Input tensors # Request 0: tokens [10, 11, 12] at positions [7, 8, 9] @@ -411,14 +415,16 @@ def test_set_inputs_first_pass_default_eagle(): ) next_token_ids = torch.tensor([100, 200, 300], dtype=torch.int32, device=device) - num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass( - target_token_ids=target_token_ids, - next_token_ids=next_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - token_indices_to_sample=None, - cad=common_attn_metadata, - num_rejected_tokens_gpu=None, + num_tokens, token_indices_to_sample, output_cm_by_gid = ( + proposer.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=None, + cm_by_gid=cm_by_gid, + num_rejected_tokens_gpu=None, + ) ) assert num_tokens == 9 # Total tokens unchanged @@ -428,7 +434,7 @@ def test_set_inputs_first_pass_default_eagle(): ) assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) - assert output_cad 
is common_attn_metadata + assert output_cm_by_gid is cm_by_gid # Verify input_ids are rotated and next_tokens inserted # Original: [10, 11, 12, 20, 21, 30, 31, 32, 33] @@ -488,6 +494,9 @@ def test_set_inputs_first_pass_draft_model(): # Create a proposer configured as a draft model (pass_hidden_states=False) # We need to mock this since _create_proposer defaults to EAGLE proposer = _create_proposer("draft_model", num_speculative_tokens) + kv_cache_group_id = 0 + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} proposer.parallel_drafting_token_id = 0 proposer.is_rejected_token_mask = torch.zeros( @@ -504,6 +513,12 @@ def test_set_inputs_first_pass_draft_model(): mock_builder.kv_cache_spec = mock_kv_cache_spec proposer.attn_metadata_builder = mock_builder + # Also mock _get_metadata_builder for per-group slot mapping + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = mock_builder + proposer.runner = mock.MagicMock() + proposer.runner.attn_groups = [[mock_attn_group]] + # Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2 batch_spec = BatchSpec( seq_lens=[3, 2], @@ -516,6 +531,7 @@ def test_set_inputs_first_pass_draft_model(): device=device, arange_block_indices=True, # Use predictable block indices ) + cm_by_gid = {kv_cache_group_id: common_attn_metadata} # Input tensors target_token_ids = torch.tensor( @@ -529,14 +545,16 @@ def test_set_inputs_first_pass_draft_model(): num_rejected_tokens_gpu = torch.tensor([1, 0], dtype=torch.int32, device=device) - num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass( - target_token_ids=target_token_ids, - next_token_ids=next_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - token_indices_to_sample=None, - cad=common_attn_metadata, - num_rejected_tokens_gpu=num_rejected_tokens_gpu, + num_tokens, token_indices_to_sample, output_cm_by_gid = ( + 
proposer.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=None, + cm_by_gid=cm_by_gid, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + ) ) assert proposer.net_num_new_slots_per_request == 1 @@ -584,6 +602,7 @@ def test_set_inputs_first_pass_draft_model(): # Verify the new CAD has updated query_start_loc # Original: [0, 3, 5] -> New: [0, 4, 7] (each request gains 1 slot) + output_cad = output_cm_by_gid[kv_cache_group_id] expected_query_start_loc = torch.tensor([0, 4, 7], dtype=torch.int32, device=device) assert torch.equal(output_cad.query_start_loc, expected_query_start_loc) @@ -624,6 +643,9 @@ def test_set_inputs_first_pass_parallel_drafting(): block_size = 16 proposer = _create_proposer("eagle", num_speculative_tokens, parallel_drafting=True) + kv_cache_group_id = 0 + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} # Override to simulate parallel drafting behavior proposer.parallel_drafting_token_id = -2 @@ -644,6 +666,12 @@ def test_set_inputs_first_pass_parallel_drafting(): mock_builder.kv_cache_spec = mock_kv_cache_spec proposer.attn_metadata_builder = mock_builder + # Also mock _get_metadata_builder for per-group slot mapping + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = mock_builder + proposer.runner = mock.MagicMock() + proposer.runner.attn_groups = [[mock_attn_group]] + # Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid) batch_spec = BatchSpec( seq_lens=[9, 14], @@ -656,6 +684,7 @@ def test_set_inputs_first_pass_parallel_drafting(): device=device, arange_block_indices=True, ) + cm_by_gid = {kv_cache_group_id: common_attn_metadata} # Input tensors target_token_ids = torch.tensor( @@ -671,14 +700,16 @@ def test_set_inputs_first_pass_parallel_drafting(): num_rejected_tokens_gpu 
= torch.tensor([1, 0], dtype=torch.int32, device=device) - num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass( - target_token_ids=target_token_ids, - next_token_ids=next_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - token_indices_to_sample=None, - cad=common_attn_metadata, - num_rejected_tokens_gpu=num_rejected_tokens_gpu, + num_tokens, token_indices_to_sample, output_cm_by_gid = ( + proposer.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=None, + cm_by_gid=cm_by_gid, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + ) ) # total_output_tokens = total_input_tokens + net_num_new_slots * batch_size @@ -730,6 +761,7 @@ def test_set_inputs_first_pass_parallel_drafting(): # Verify the new CAD has updated query_start_loc # Original query_lens: [4, 4] -> Output: [6, 6] + output_cad = output_cm_by_gid[kv_cache_group_id] expected_query_start_loc = torch.tensor( [0, 6, 12], dtype=torch.int32, device=device )