diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index a141e9da08a1..5992a9413885 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -613,6 +613,15 @@ def test_mtp_correctness( cleanup_dist_env_and_memory() +def some_high_acceptance_metrics() -> dict: + return { + "sampling_config": greedy_sampling(), + "num_speculative_tokens": 3, + "expected_acceptance_len": 2.8 + 1, + "expected_acceptance_rate": 0.90, + } + + @dataclass class ArgsTest: target_model: str @@ -630,6 +639,8 @@ class ArgsTest: gpu_memory_utilization: float = 0.5 dataset: str = "test_prompts" num_prompts: int = 100 + # Some settings only get 100% acceptance_rate with VLLM_BATCH_INVARIANT=1 + use_batch_invariance: bool = False cases = [ @@ -651,12 +662,36 @@ class ArgsTest: expected_acceptance_len=2.8 + 1, expected_acceptance_rate=0.9, ), + # Multiple KV Cache groups + ArgsTest( + target_model="google/gemma-3-270m-it", + draft_model="google/gemma-3-270m-it", + sampling_config=greedy_sampling(), + num_speculative_tokens=3, + expected_acceptance_len=4, + expected_acceptance_rate=1, + # Without batch invariance, acceptance rate is ~86% + use_batch_invariance=True, + ), + # GPT-OSS MoE models with different target/draft sizes + # Tests MoE layer resolution in speculative decoding. + ArgsTest( + target_model="openai/gpt-oss-120b", + draft_model="openai/gpt-oss-20b", + # Leave some headroom for CUDA graph capture. 
+ gpu_memory_utilization=0.85, + **some_high_acceptance_metrics(), + ), ] @pytest.mark.parametrize("args", cases) @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool): +def test_draft_model_correctness( + args: ArgsTest, enforce_eager: bool, monkeypatch: pytest.MonkeyPatch +): + if args.use_batch_invariance: + monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1") args.enforce_eager = enforce_eager assert_draft_model_correctness(args) @@ -772,6 +807,8 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname(): def assert_draft_model_correctness(args: ArgsTest): """Compare the outputs using and not using speculative decoding. In the greedy decoding case, the outputs must match EXACTLY.""" + attention_config = {"backend": "FLASH_ATTN"} if args.use_batch_invariance else None + test_prompts: list[Messages] = get_messages( dataset=args.dataset, n=args.num_prompts ) @@ -793,6 +830,7 @@ def assert_draft_model_correctness(args: ArgsTest): tensor_parallel_size=args.target_tensor_parallel_size, enforce_eager=args.enforce_eager, disable_log_stats=False, # enables get_metrics() + attention_config=attention_config, ) # we don't check the outputs, only check the metrics spec_llm.chat(test_prompts, args.sampling_config) @@ -824,15 +862,6 @@ def get_messages(dataset: str, n: int) -> list[Messages]: raise NotImplementedError(f"Dataset '{dataset}' not implemented") -def some_high_acceptance_metrics() -> dict: - return { - "sampling_config": greedy_sampling(), - "num_speculative_tokens": 3, - "expected_acceptance_len": 2.8 + 1, - "expected_acceptance_rate": 0.90, - } - - def compute_acceptance_rate(metrics: list[Metric]) -> float: name2metric = {metric.name: metric for metric in metrics} n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 65e97b7ad5b0..5217fc3e19b1 100644 --- 
a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -166,6 +166,11 @@ def test_prepare_next_token_ids(): block_size=16, device=device, ) + kv_cache_group_id = 0 + common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} + # Mock connection of layer names to kv_cache_group_ids + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} expected_valid_sampled_tokens_count = torch.tensor( [2, 5, 0, 0], dtype=torch.int32, device=device @@ -173,7 +178,7 @@ def test_prepare_next_token_ids(): next_token_ids_from_padded, valid_sampled_tokens_count = ( proposer.prepare_next_token_ids_padded( - common_attn_metadata, + common_attn_metadata_by_gid, sampled_token_ids_tensor, mock_requests, mock_input_batch, @@ -266,10 +271,17 @@ def test_prepare_inputs(): ) proposer = _create_proposer("eagle", 1) - updated_metadata, token_indices = proposer.prepare_inputs( - common_attn_metadata, sampled_token_ids, num_draft_tokens + # Mock connection of layer names to kv_cache_group_ids + kv_cache_group_id = 0 + common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} + + updated_cm_by_gid, token_indices = proposer.prepare_inputs( + common_attn_metadata_by_gid, sampled_token_ids, num_draft_tokens ) + updated_metadata = updated_cm_by_gid[kv_cache_group_id] assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens) assert token_indices.shape[0] == expected_cu_num_tokens[-1].item() assert torch.equal(token_indices, expected_token_indices) @@ -324,9 +336,17 @@ def test_prepare_inputs_padded(): proposer = _create_proposer("eagle", num_speculative_tokens) - output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = ( + # Mock connection of layer names to kv_cache_group_ids + kv_cache_group_id = 0 + common_attn_metadata_by_gid = {kv_cache_group_id: 
common_attn_metadata} + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} + + output_cm_by_gid, token_indices_to_sample, num_rejected_tokens_gpu = ( proposer.prepare_inputs_padded( - common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count + common_attn_metadata_by_gid, + spec_decode_metadata, + valid_sampled_tokens_count, ) ) @@ -334,6 +354,7 @@ def test_prepare_inputs_padded(): expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device) assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected) + output_metadata = output_cm_by_gid[kv_cache_group_id] assert output_metadata.max_query_len == 3 assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc) assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) @@ -362,6 +383,9 @@ def test_set_inputs_first_pass_default_eagle(): num_speculative_tokens = 3 proposer = _create_proposer("eagle", num_speculative_tokens) + kv_cache_group_id = 0 + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} # Setup batch with 3 requests batch_spec = BatchSpec( @@ -374,6 +398,7 @@ def test_set_inputs_first_pass_default_eagle(): block_size=16, device=device, ) + cm_by_gid = {kv_cache_group_id: common_attn_metadata} # Input tensors # Request 0: tokens [10, 11, 12] at positions [7, 8, 9] @@ -390,14 +415,16 @@ def test_set_inputs_first_pass_default_eagle(): ) next_token_ids = torch.tensor([100, 200, 300], dtype=torch.int32, device=device) - num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass( - target_token_ids=target_token_ids, - next_token_ids=next_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - token_indices_to_sample=None, - cad=common_attn_metadata, - num_rejected_tokens_gpu=None, + num_tokens, token_indices_to_sample, output_cm_by_gid = ( + 
proposer.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=None, + cm_by_gid=cm_by_gid, + num_rejected_tokens_gpu=None, + ) ) assert num_tokens == 9 # Total tokens unchanged @@ -407,7 +434,7 @@ def test_set_inputs_first_pass_default_eagle(): ) assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) - assert output_cad is common_attn_metadata + assert output_cm_by_gid is cm_by_gid # Verify input_ids are rotated and next_tokens inserted # Original: [10, 11, 12, 20, 21, 30, 31, 32, 33] @@ -467,6 +494,9 @@ def test_set_inputs_first_pass_draft_model(): # Create a proposer configured as a draft model (pass_hidden_states=False) # We need to mock this since _create_proposer defaults to EAGLE proposer = _create_proposer("draft_model", num_speculative_tokens) + kv_cache_group_id = 0 + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} proposer.parallel_drafting_token_id = 0 proposer.is_rejected_token_mask = torch.zeros( @@ -483,6 +513,12 @@ def test_set_inputs_first_pass_draft_model(): mock_builder.kv_cache_spec = mock_kv_cache_spec proposer.attn_metadata_builder = mock_builder + # Also mock _get_metadata_builder for per-group slot mapping + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = mock_builder + proposer.runner = mock.MagicMock() + proposer.runner.attn_groups = [[mock_attn_group]] + # Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2 batch_spec = BatchSpec( seq_lens=[3, 2], @@ -495,6 +531,7 @@ def test_set_inputs_first_pass_draft_model(): device=device, arange_block_indices=True, # Use predictable block indices ) + cm_by_gid = {kv_cache_group_id: common_attn_metadata} # Input tensors target_token_ids = torch.tensor( @@ -508,14 +545,16 @@ def test_set_inputs_first_pass_draft_model(): 
num_rejected_tokens_gpu = torch.tensor([1, 0], dtype=torch.int32, device=device) - num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass( - target_token_ids=target_token_ids, - next_token_ids=next_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - token_indices_to_sample=None, - cad=common_attn_metadata, - num_rejected_tokens_gpu=num_rejected_tokens_gpu, + num_tokens, token_indices_to_sample, output_cm_by_gid = ( + proposer.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=None, + cm_by_gid=cm_by_gid, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + ) ) assert proposer.net_num_new_slots_per_request == 1 @@ -563,6 +602,7 @@ def test_set_inputs_first_pass_draft_model(): # Verify the new CAD has updated query_start_loc # Original: [0, 3, 5] -> New: [0, 4, 7] (each request gains 1 slot) + output_cad = output_cm_by_gid[kv_cache_group_id] expected_query_start_loc = torch.tensor([0, 4, 7], dtype=torch.int32, device=device) assert torch.equal(output_cad.query_start_loc, expected_query_start_loc) @@ -603,6 +643,9 @@ def test_set_inputs_first_pass_parallel_drafting(): block_size = 16 proposer = _create_proposer("eagle", num_speculative_tokens, parallel_drafting=True) + kv_cache_group_id = 0 + proposer.attn_layer_names = ["layer0"] + proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id} # Override to simulate parallel drafting behavior proposer.parallel_drafting_token_id = -2 @@ -623,6 +666,12 @@ def test_set_inputs_first_pass_parallel_drafting(): mock_builder.kv_cache_spec = mock_kv_cache_spec proposer.attn_metadata_builder = mock_builder + # Also mock _get_metadata_builder for per-group slot mapping + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = mock_builder + proposer.runner = mock.MagicMock() 
+ proposer.runner.attn_groups = [[mock_attn_group]] + # Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid) batch_spec = BatchSpec( seq_lens=[9, 14], @@ -635,6 +684,7 @@ def test_set_inputs_first_pass_parallel_drafting(): device=device, arange_block_indices=True, ) + cm_by_gid = {kv_cache_group_id: common_attn_metadata} # Input tensors target_token_ids = torch.tensor( @@ -650,14 +700,16 @@ def test_set_inputs_first_pass_parallel_drafting(): num_rejected_tokens_gpu = torch.tensor([1, 0], dtype=torch.int32, device=device) - num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass( - target_token_ids=target_token_ids, - next_token_ids=next_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - token_indices_to_sample=None, - cad=common_attn_metadata, - num_rejected_tokens_gpu=num_rejected_tokens_gpu, + num_tokens, token_indices_to_sample, output_cm_by_gid = ( + proposer.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=None, + cm_by_gid=cm_by_gid, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + ) ) # total_output_tokens = total_input_tokens + net_num_new_slots * batch_size @@ -709,6 +761,7 @@ def test_set_inputs_first_pass_parallel_drafting(): # Verify the new CAD has updated query_start_loc # Original query_lens: [4, 4] -> Output: [6, 6] + output_cad = output_cm_by_gid[kv_cache_group_id] expected_query_start_loc = torch.tensor( [0, 6, 12], dtype=torch.int32, device=device ) @@ -968,13 +1021,16 @@ def create_deterministic_logits(token_ids): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() - proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][ - 0 - ].get_metadata_builder.return_value = attn_metadata_builder + attn_group = mock.MagicMock() + 
attn_group.get_metadata_builder.return_value = attn_metadata_builder + attn_group.layer_names = proposer.attn_layer_names + proposer.runner.attn_groups = [[attn_group]] proposer._get_attention_metadata_builder = mock.MagicMock( return_value=attn_metadata_builder ) + kv_cache_group_id = 0 + common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} + proposer.layer_names_to_kv_cache_gid = {"layer.0": kv_cache_group_id} result = proposer.propose( target_token_ids=target_token_ids, @@ -982,8 +1038,8 @@ def create_deterministic_logits(token_ids): target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, token_indices_to_sample=None, - common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, + cm_by_gid=common_attn_metadata_by_gid, ) assert result.shape == (batch_size, num_speculative_tokens) @@ -1104,14 +1160,16 @@ def create_deterministic_logits(token_ids, k: int): # Mock runner for attention metadata building. proposer.runner = mock.MagicMock() - proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder] - proposer.runner.attn_groups[0][ - 0 - ].get_metadata_builder.return_value = attn_metadata_builder + attn_group = mock.MagicMock() + attn_group.layer_names = proposer.attn_layer_names + attn_group.metadata_builders = [attn_metadata_builder] + attn_group.get_metadata_builder.return_value = attn_metadata_builder + proposer.runner.attn_groups = [[attn_group]] proposer._get_attention_metadata_builder = mock.MagicMock( return_value=attn_metadata_builder ) + kv_cache_group_id = 0 + proposer.layer_names_to_kv_cache_gid = {"layer.0": kv_cache_group_id} # Setup inputs for the proposer. 
target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) @@ -1131,6 +1189,7 @@ def create_deterministic_logits(token_ids, k: int): block_size=16, device=device, ) + common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata} sampling_metadata = mock.MagicMock() # Propose draft tokens. @@ -1140,8 +1199,8 @@ def create_deterministic_logits(token_ids, k: int): target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, token_indices_to_sample=None, - common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, + cm_by_gid=common_attn_metadata_by_gid, ) assert result.shape == (batch_size, num_speculative_tokens) diff --git a/tests/v1/worker/test_utils.py b/tests/v1/worker/test_utils.py index 76f9a8f90f70..d861c7c3c598 100644 --- a/tests/v1/worker/test_utils.py +++ b/tests/v1/worker/test_utils.py @@ -3,6 +3,7 @@ import torch +from vllm.v1.core.kv_cache_utils import DRAFT_MODEL_PREFIX from vllm.v1.worker.utils import bind_kv_cache @@ -63,8 +64,8 @@ def test_bind_kv_cache_draft_model(default_vllm_config): layer_names = [ "model.layers.0.attn", "model.layers.1.attn", - "draft_model.layers.0.attn", - "draft_model.layers.1.attn", + f"{DRAFT_MODEL_PREFIX}.layers.0.attn", + f"{DRAFT_MODEL_PREFIX}.layers.1.attn", ] ctx = { layer_name: Attention(32, 128, 0.1, prefix=layer_name) @@ -76,17 +77,13 @@ def test_bind_kv_cache_draft_model(default_vllm_config): assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"] assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"] - assert ( - ctx["draft_model.layers.0.attn"].kv_cache[0] - is kv_cache["draft_model.layers.0.attn"] - ) - assert ( - ctx["draft_model.layers.1.attn"].kv_cache[0] - is kv_cache["draft_model.layers.1.attn"] - ) + draft_layer_0 = f"{DRAFT_MODEL_PREFIX}.layers.0.attn" + draft_layer_1 = f"{DRAFT_MODEL_PREFIX}.layers.1.attn" + assert ctx[draft_layer_0].kv_cache[0] is kv_cache[draft_layer_0] + assert 
ctx[draft_layer_1].kv_cache[0] is kv_cache[draft_layer_1] # caches are ordered by layer_index, interleaving target and draft model assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"] - assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"] + assert runner_kv_caches[1] is kv_cache[draft_layer_0] assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"] - assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"] + assert runner_kv_caches[3] is kv_cache[draft_layer_1] diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index f31e2635a0f1..e65489e7b6b7 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -415,6 +415,12 @@ def unpadded( ) +# Mapping from KV cache group ID to its CommonAttentionMetadata. +# Each KV cache group may have different block_table_tensor and slot_mapping, +# while sharing other fields like query_start_loc and seq_lens. +CommonAttnMetadataByGid = dict[int, CommonAttentionMetadata] + + M = TypeVar("M") diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 2f59e71a13df..694b562abd3d 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -8,7 +8,12 @@ from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass, replace from functools import partial -from typing import Any, NewType, TypeAlias, overload +from typing import TYPE_CHECKING, Any, NewType, TypeAlias, overload + +import torch + +if TYPE_CHECKING: + from vllm.v1.worker.ubatch_utils import UBatchSlices from vllm import envs from vllm.config import VllmConfig @@ -75,6 +80,10 @@ def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash: logger = init_logger(__name__) +# Prefix for draft model layer names in the KV cache (e.g. "draft_model.layers.0.attn"). +# Must match the prefix passed to get_model() when loading the draft model. 
+DRAFT_MODEL_PREFIX = "draft_model" + # The hash seed for the first block of any prefix block sequence. # # We use a random value to avoid hash collisions or PYTHONHASHSEED environment @@ -1012,12 +1021,15 @@ def _get_kv_cache_groups_uniform_page_size( Returns: The generated KVCacheGroupSpecs """ - # Group all layers by kv_cache_spec. + # Group all layers by kv_cache_spec AND whether it's a draft or target layer. # E.g., 2 full attention layers and 3 sliding window attention layers, # -> (full.0, full.1), (sw.0, sw.1, sw.2). - same_type_layers: dict[KVCacheSpec, list[str]] = defaultdict(list) + # Draft and target layers are kept in separate groups to prevent KV cache + # cross-contamination during speculative decoding. + same_type_layers: dict[tuple[KVCacheSpec, bool], list[str]] = defaultdict(list) for layer_name, layer_spec in kv_cache_spec.items(): - same_type_layers[layer_spec].append(layer_name) + is_draft_layer = layer_name.startswith(DRAFT_MODEL_PREFIX) + same_type_layers[(layer_spec, is_draft_layer)].append(layer_name) # Split each group into smaller groups, to make the number of layers in each # group identical. Add padding to the last group of each type if necessary. @@ -1678,3 +1690,49 @@ def _get_value_at(self, idx: int) -> BlockHash: BlockHashList = list[BlockHash] | BlockHashListWithBlockSize + + +def get_slot_mappings_by_layer( + kv_cache_config: KVCacheConfig, + slot_mappings_by_gid: dict[int, torch.Tensor], + ubatch_slices: "UBatchSlices | None" = None, +) -> dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]: + """ + Convert slot mappings from group ID indexing to layer name indexing. + + Args: + kv_cache_config: KV cache configuration containing group-to-layer mappings. + slot_mappings_by_gid: Slot mappings keyed by KV cache group ID. + ubatch_slices: Optional ubatch slicing info for DBO (Dual Batch + Overlap). When provided, returns a list of sliced + mappings per ubatch. 
+ + Returns: + dict[str, torch.Tensor]: Slot mappings keyed by layer name for + ForwardContext, or list[dict[str, torch.Tensor]] when ubatch_slices + is provided. + """ + slot_mappings_by_layer: dict[str, torch.Tensor] = { + layer_name: slot_mappings_by_gid[gid] + for layer_name, gid in kv_cache_group_id_by_layer(kv_cache_config).items() + } + + if ubatch_slices is not None: + result: list[dict[str, torch.Tensor]] = [] + for ubatch in ubatch_slices: + sliced_mappings: dict[str, torch.Tensor] = {} + for layer_name, slot_mapping in slot_mappings_by_layer.items(): + sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice] + result.append(sliced_mappings) + return result + + return slot_mappings_by_layer + + +def kv_cache_group_id_by_layer(kv_cache_config: KVCacheConfig) -> dict[str, int]: + """Return a mapping from layer_name -> KV cache group ID.""" + gid_by_layer: dict[str, int] = {} + for gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + for layer_name in kv_cache_group.layer_names: + gid_by_layer[layer_name] = gid + return gid_by_layer diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a6e7995bc7cb..8a099a26827d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast +from collections import defaultdict +from collections.abc import Iterator from dataclasses import replace +from functools import cached_property from importlib.util import find_spec from typing import cast @@ -28,8 +31,10 @@ from vllm.triton_utils import triton from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backend import ( + AttentionMetadata, AttentionMetadataBuilder, CommonAttentionMetadata, + CommonAttnMetadataByGid, ) from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.tree_attn import ( @@ -37,8 +42,11 
@@ TreeAttentionMetadataBuilder, ) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata +from vllm.v1.core.kv_cache_utils import ( + get_slot_mappings_by_layer, + kv_cache_group_id_by_layer, +) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher -from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -208,8 +216,9 @@ def __init__( with_numpy=True, ) - self._slot_mapping_buffer = torch.zeros( - self.max_num_tokens, dtype=torch.int64, device=device + # Per-gid slot-mapping buffers (lazily created via defaultdict). + self._slot_mapping_buffer_by_gid: defaultdict[int, torch.Tensor] = defaultdict( + self.new_slot_mapping_buffer ) # Determine allowed attention backends once during initialization. @@ -274,6 +283,21 @@ def __init__( 1, len(self.tree_choices) + 1, device=device, dtype=torch.int32 ).repeat(max_batch_size, 1) + @cached_property + def layer_names_to_kv_cache_gid(self) -> dict[str, int]: + return kv_cache_group_id_by_layer(self.runner.kv_cache_config) + + def pick_first_layer_common_attn_metadata( + self, cm_by_gid: CommonAttnMetadataByGid + ) -> CommonAttentionMetadata: + """Pick the CommonAttentionMetadata for the drafter's first attention layer. + + The drafter may have attention layers in different KV cache groups. + This method picks the CAM corresponding to the first attention layer. 
+ """ + gid = self.layer_names_to_kv_cache_gid[self.attn_layer_names[0]] + return cm_by_gid[gid] + def _raise_if_padded_drafter_batch_disabled(self): if self.speculative_config.disable_padded_drafter_batch: raise NotImplementedError( @@ -337,22 +361,52 @@ def _set_positions(self, num_tokens: int, positions: torch.Tensor): positions = positions[0] self.positions[:num_tokens] = positions + def new_slot_mapping_buffer(self) -> torch.Tensor: + """Create a new slot-mapping buffer for one KV cache group.""" + return torch.zeros( + self.max_num_tokens, + dtype=torch.int64, + device=self.hidden_states.device, + ) + def _get_slot_mapping( - self, - num_tokens: int, - slot_mapping: torch.Tensor | None = None, - ) -> dict[str, torch.Tensor]: - """Return slot_mapping dict for EAGLE layers. + self, cm_by_gid: CommonAttnMetadataByGid + ) -> dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]: + """Return slot_mapping dict for EAGLE/draft layers during inference. - If slot_mapping is provided, copies it into the buffer first. + Copies per-group slot mappings from CommonAttentionMetadata into + per-gid buffers and returns views into those buffers. """ - if slot_mapping is not None: - num_actual = slot_mapping.shape[0] - self._slot_mapping_buffer[:num_actual].copy_(slot_mapping) - if num_tokens > num_actual: - self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID) + slot_mapping_by_gid = dict(self._set_and_get_slot_mapping_bufs(cm_by_gid)) + return get_slot_mappings_by_layer( + self.runner.kv_cache_config, + slot_mapping_by_gid, + ubatch_slices=None, + ) - view = self._slot_mapping_buffer[:num_tokens] + def _set_and_get_slot_mapping_bufs( + self, cm_by_gid: CommonAttnMetadataByGid + ) -> Iterator[tuple[int, torch.Tensor]]: + """ + Set the slot mapping buffers for each KV cache group. + Then yield the buffer for each KV cache group ID. 
+ """ + for gid, cm in cm_by_gid.items(): + src = cm.slot_mapping + src_len = src.shape[0] + buf = self._slot_mapping_buffer_by_gid[gid] + n = min(cm.num_actual_tokens, src_len) + buf[:n].copy_(src[:n]) + if n < src_len: + buf[n:src_len].fill_(PADDING_SLOT_ID) + yield gid, buf[:src_len] + + def _get_slot_mapping_dummy_run(self, num_tokens: int) -> dict[str, torch.Tensor]: + """Return slot_mapping dict for EAGLE/draft layers during dummy run. + Uses per-gid buffer (gid 0). + """ + buf = self._slot_mapping_buffer_by_gid[0] + view = buf[:num_tokens] return {name: view for name in self.attn_layer_names + self.indexer_layer_names} def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: @@ -383,16 +437,15 @@ def propose( # [batch_size] next_token_ids: torch.Tensor, token_indices_to_sample: torch.Tensor | None, - common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, + # CommonAttentionMetadata for each KV cache group ID. + cm_by_gid: CommonAttnMetadataByGid, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, ) -> torch.Tensor: - batch_size = common_attn_metadata.batch_size() - if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) target_hidden_states = self.model.combine_hidden_states( @@ -400,28 +453,35 @@ def propose( ) assert target_hidden_states.shape[-1] == self.hidden_size - num_tokens, token_indices_to_sample, common_attn_metadata = ( - self.set_inputs_first_pass( - target_token_ids=target_token_ids, - next_token_ids=next_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - token_indices_to_sample=token_indices_to_sample, - cad=common_attn_metadata, - num_rejected_tokens_gpu=num_rejected_tokens_gpu, - ) + num_tokens, token_indices_to_sample, cm_by_gid = self.set_inputs_first_pass( + 
target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=token_indices_to_sample, + cm_by_gid=cm_by_gid, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, ) assert self.runner is not None + common_attn_metadata = self.pick_first_layer_common_attn_metadata(cm_by_gid) + batch_size = common_attn_metadata.batch_size() - if self.attn_metadata_builder is None: - attn_metadata_builder = self._get_attention_metadata_builder() - else: - attn_metadata_builder = self.attn_metadata_builder + # Build attention metadata for each KV cache group + # then build it for each layer. + attn_metadata_by_gid: dict[int, AttentionMetadata] = {} + for gid, cm in cm_by_gid.items(): + builder = self._get_metadata_builder(gid) + attn_metadata = builder.build_for_drafting( + common_attn_metadata=cm, draft_index=0 + ) + attn_metadata_by_gid[gid] = attn_metadata + + per_layer_attn_metadata: dict[str, AttentionMetadata] = {} + for layer_name in self.attn_layer_names: + gid = self.layer_names_to_kv_cache_gid[layer_name] + per_layer_attn_metadata[layer_name] = attn_metadata_by_gid[gid] - attn_metadata = attn_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, draft_index=0 - ) # FIXME: support hybrid kv for draft model (remove separate indexer) if self.draft_indexer_metadata_builder: draft_indexer_metadata = ( @@ -432,11 +492,6 @@ def propose( ) else: draft_indexer_metadata = None - # At this moment, we assume all eagle layers belong to the same KV - # cache group, thus using the same attention metadata. 
- per_layer_attn_metadata = {} - for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata for layer_name in self.indexer_layer_names: assert draft_indexer_metadata is not None @@ -482,9 +537,7 @@ def propose( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, - slot_mapping=self._get_slot_mapping( - num_input_tokens, common_attn_metadata.slot_mapping - ), + slot_mapping=self._get_slot_mapping(cm_by_gid=cm_by_gid), ): ret_hidden_states = self.model(**model_kwargs) if not self.model_returns_tuple(): @@ -524,6 +577,7 @@ def propose( hidden_states=hidden_states, common_attn_metadata=common_attn_metadata, slot_mappings=slot_mappings, + cm_by_gid=cm_by_gid, ) # [batch_size, num_tree_tokens] return torch.cat(draft_token_ids_list, dim=1) @@ -620,38 +674,55 @@ def propose( if common_attn_metadata._num_computed_tokens_cpu is not None: common_attn_metadata._num_computed_tokens_cpu += 1 - # Compute the slot mapping. - block_size = attn_metadata_builder.kv_cache_spec.block_size - if self.uses_mrope: - # all dimensions of positions are the same - block_numbers = clamped_positions[0] // block_size - else: - block_numbers = clamped_positions // block_size - block_ids = common_attn_metadata.block_table_tensor.gather( - dim=1, index=block_numbers.view(-1, 1) - ) - block_ids = block_ids.view(-1) - if self.uses_mrope: - common_attn_metadata.slot_mapping = ( - block_ids * block_size + clamped_positions[0] % block_size + # Update all other CommonAttentionMetadata objects with the updated + # information computed above. 
+ for gid, cm in cm_by_gid.items(): + cm_by_gid[gid] = common_attn_metadata.replace( + slot_mapping=cm.slot_mapping, + block_table_tensor=cm.block_table_tensor, ) - else: - common_attn_metadata.slot_mapping = ( - block_ids * block_size + clamped_positions % block_size + + # Compute the slot mapping for each kv-cache group + attn_metadata_by_gid = {} + for gid, cm in cm_by_gid.items(): + blk_table_tensor = cm.block_table_tensor + + # Compute per-group block_size and block_numbers + builder = self._get_metadata_builder(gid) + group_block_size = builder.kv_cache_spec.block_size + if self.uses_mrope: + group_block_numbers = clamped_positions[0] // group_block_size + else: + group_block_numbers = clamped_positions // group_block_size + + block_ids = blk_table_tensor.gather( + dim=1, index=group_block_numbers.view(-1, 1) ) - # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadvertently updated with the - # padding tokens. - common_attn_metadata.slot_mapping.masked_fill_( - exceeds_max_model_len, PADDING_SLOT_ID - ) + block_ids = block_ids.view(-1) + if self.uses_mrope: + slot_mapping = ( + block_ids * group_block_size + + clamped_positions[0] % group_block_size + ) + else: + slot_mapping = ( + block_ids * group_block_size + + clamped_positions % group_block_size + ) + slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + + new_cm = cm_by_gid[gid].replace(slot_mapping=slot_mapping) + attn_metadata = builder.build_for_drafting( + common_attn_metadata=new_cm, + draft_index=token_index + 1, + ) + cm_by_gid[gid] = new_cm + attn_metadata_by_gid[gid] = attn_metadata # Rebuild attention metadata - attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore - common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 - ) for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata + gid = self.layer_names_to_kv_cache_gid[layer_name] + 
per_layer_attn_metadata[layer_name] = attn_metadata_by_gid[gid] # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids @@ -681,9 +752,7 @@ def propose( num_tokens=input_batch_size, num_tokens_across_dp=batch_size_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, - slot_mapping=self._get_slot_mapping( - input_batch_size, common_attn_metadata.slot_mapping - ), + slot_mapping=self._get_slot_mapping(cm_by_gid), ): ret_hidden_states = self.model(**model_kwargs) if not self.model_returns_tuple(): @@ -708,9 +777,10 @@ def set_inputs_first_pass( target_positions: torch.Tensor, target_hidden_states: torch.Tensor, token_indices_to_sample: torch.Tensor | None, - cad: CommonAttentionMetadata, + cm_by_gid: CommonAttnMetadataByGid, num_rejected_tokens_gpu: torch.Tensor | None, - ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]: + ) -> tuple[int, torch.Tensor, CommonAttnMetadataByGid]: + cad = self.pick_first_layer_common_attn_metadata(cm_by_gid) if not self.needs_extra_input_slots: # Default EAGLE pathway: no reshaping of input tensors needed. # Simply rotate the input ids and leave the positions unchanged, @@ -733,7 +803,7 @@ def set_inputs_first_pass( self.hidden_states[:num_tokens] = target_hidden_states - return num_tokens, token_indices_to_sample, cad + return num_tokens, token_indices_to_sample, cm_by_gid else: assert self.is_rejected_token_mask is not None assert self.is_masked_token_mask is not None @@ -812,33 +882,32 @@ def set_inputs_first_pass( ) # 2. - # Recompute the slot mapping based on the new positions and - # rejection mask. 
- builder = ( - self._get_attention_metadata_builder() - if self.attn_metadata_builder is None - else self.attn_metadata_builder - ) - new_slot_mapping = compute_new_slot_mapping( - cad=cad, - new_positions=self.positions[:total_num_output_tokens], - is_rejected_token_mask=self.is_rejected_token_mask[ - :total_num_output_tokens - ], - block_size=builder.kv_cache_spec.block_size, - num_new_tokens=self.net_num_new_slots_per_request, - max_model_len=self.max_model_len, - ) + # Recompute the slot mapping per KV cache group, based on the + # new positions and rejection mask. + new_cm_by_gid: CommonAttnMetadataByGid = {} + for gid, cm in cm_by_gid.items(): + builder = self._get_metadata_builder(gid) + new_slot_mapping = compute_new_slot_mapping( + cad=cm, + new_positions=self.positions[:total_num_output_tokens], + is_rejected_token_mask=self.is_rejected_token_mask[ + :total_num_output_tokens + ], + block_size=builder.kv_cache_spec.block_size, + num_new_tokens=self.net_num_new_slots_per_request, + max_model_len=self.max_model_len, + ) - # 3. Update the common attention metadata with the new (meta)data - new_cad = extend_all_queries_by_N( - cad, - N=self.net_num_new_slots_per_request, - arange=self.arange, - new_slot_mapping=new_slot_mapping, - ) + # 3. 
Update the common attention metadata with the new (meta)data + new_cm = extend_all_queries_by_N( + cm, + N=self.net_num_new_slots_per_request, + arange=self.arange, + new_slot_mapping=new_slot_mapping, + ) + new_cm_by_gid[gid] = new_cm - return total_num_output_tokens, token_indices_to_sample, new_cad + return total_num_output_tokens, token_indices_to_sample, new_cm_by_gid def model_returns_tuple(self) -> bool: return self.method not in ("mtp", "draft_model") @@ -878,7 +947,7 @@ def prepare_next_token_ids_cpu( def prepare_next_token_ids_padded( self, - common_attn_metadata: CommonAttentionMetadata, + cm_by_gid: CommonAttnMetadataByGid, sampled_token_ids: torch.Tensor, requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, @@ -891,6 +960,7 @@ def prepare_next_token_ids_padded( is not sampled and comes from `request.get_token_id()` instead. This is denoted the "backup" token id. It also counts rejected tokens via `sampled_token_ids`. """ + common_attn_metadata = self.pick_first_layer_common_attn_metadata(cm_by_gid) # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs self.backup_next_token_ids.np[:num_reqs] = np.array( @@ -936,10 +1006,10 @@ def prepare_next_token_ids_padded( def prepare_inputs_padded( self, - common_attn_metadata: CommonAttentionMetadata, + cm_by_gid: CommonAttnMetadataByGid, spec_decode_metadata: SpecDecodeMetadata, valid_sampled_tokens_count: torch.Tensor, - ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + ) -> tuple[CommonAttnMetadataByGid, torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding It updates the common_attn_metadata for speculative decoding, @@ -948,6 +1018,9 @@ def prepare_inputs_padded( used as padding and filtered out later by `token_indices_to_sample`. No blocking CPU operations should be introduced in this function. 
""" + # Pick the CAM for the drafter's first attention layer as the base + common_attn_metadata = self.pick_first_layer_common_attn_metadata(cm_by_gid) + num_reqs = common_attn_metadata.num_reqs device = valid_sampled_tokens_count.device @@ -973,7 +1046,8 @@ def prepare_inputs_padded( total_num_tokens = query_start_loc_cpu[-1].item() - spec_common_attn_metadata = CommonAttentionMetadata( + # First create the new base CAM with transformed shared fields + new_cam = CommonAttentionMetadata( query_start_loc=common_attn_metadata.query_start_loc, seq_lens=common_attn_metadata.seq_lens, query_start_loc_cpu=query_start_loc_cpu, @@ -983,14 +1057,24 @@ def prepare_inputs_padded( num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens], + # block-table and slot-mapping are set below + block_table_tensor=torch.empty(0), + slot_mapping=torch.empty(0), causal=True, dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) + # Then create new CAMs for all groups, preserving each group's + # block_table_tensor and slot_mapping + new_cm_by_gid: CommonAttnMetadataByGid = {} + for gid, cm in cm_by_gid.items(): + new_cm_by_gid[gid] = new_cam.replace( + block_table_tensor=cm.block_table_tensor, + slot_mapping=cm.slot_mapping[:total_num_tokens], + ) + return ( - spec_common_attn_metadata, + new_cm_by_gid, token_indices_to_sample, num_rejected_tokens_gpu, ) @@ -1005,6 +1089,7 @@ def propose_tree( # [num_tokens, hidden_size] hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, + cm_by_gid: CommonAttnMetadataByGid, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, @@ -1129,9 +1214,7 @@ def propose_tree( self.vllm_config, num_tokens=num_input_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, - 
slot_mapping=self._get_slot_mapping( - num_input_tokens, attn_metadata.slot_mapping - ), + slot_mapping=self._get_slot_mapping(cm_by_gid), ): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], @@ -1170,16 +1253,19 @@ def propose_tree( def prepare_inputs( self, - common_attn_metadata: CommonAttentionMetadata, + cm_by_gid: CommonAttnMetadataByGid, sampled_token_ids: list[list[int]], num_draft_tokens: list[int], - ) -> tuple[CommonAttentionMetadata, torch.Tensor]: + ) -> tuple[CommonAttnMetadataByGid, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding. It updates to the common_attn_metadata to account for the rejected tokens (and newly sampled tokens). It also returns the token indices of the tokens that should be fed to the speculator. """ + # Pick the CAM for the drafter's first attention layer as the base + common_attn_metadata = self.pick_first_layer_common_attn_metadata(cm_by_gid) + # E.g. # common_attn_metadata.query_start_loc{_cpu}: # [0, q1, q1 + q2, q1 + q2 + q3] @@ -1252,7 +1338,8 @@ def prepare_inputs( token_indices_np = token_offsets + old_query_start_locs_expanded token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True) - spec_common_attn_metadata = CommonAttentionMetadata( + # First create the new base CAM with transformed shared fields + new_cam = CommonAttentionMetadata( query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), query_start_loc_cpu=new_query_start_loc_cpu, @@ -1262,13 +1349,23 @@ def prepare_inputs( num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), max_seq_len=new_seq_lens_cpu.max().item(), - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[token_indices], + # block-table and slot-mapping are set below + block_table_tensor=torch.empty(0), + slot_mapping=torch.empty(0), 
causal=True, dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) - return spec_common_attn_metadata, token_indices + # Then create new CAMs for all groups, preserving each group's + # block_table_tensor and slot_mapping + new_cm_by_gid: CommonAttnMetadataByGid = {} + for gid, cm in cm_by_gid.items(): + new_cm_by_gid[gid] = new_cam.replace( + block_table_tensor=cm.block_table_tensor, + slot_mapping=cm.slot_mapping[token_indices], + ) + + return new_cm_by_gid, token_indices def get_model_name(self, model: nn.Module) -> str: if hasattr(model, "module"): # multi-GPU @@ -1558,7 +1655,9 @@ def dummy_run( and slot_mappings is not None and self.attn_layer_names[0] in slot_mappings ): - slot_mapping_dict = self._get_slot_mapping(num_input_tokens) + slot_mapping_dict = self._get_slot_mapping_dummy_run( + num_tokens=num_input_tokens + ) else: slot_mapping_dict = slot_mappings or {} @@ -1586,30 +1685,8 @@ def dummy_run( kwargs["hidden_states"] = self.hidden_states[:num_input_tokens] self.model(**kwargs) - def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: - """Find and return the attention metadata builders for EAGLE layers. - - Returns: - The metadata builders for EAGLE layers. - - Raises: - AssertionError: If no metadata builders are found for EAGLE layers. - """ - builder = None - chosen_layer = self.attn_layer_names[0] - - for kv_cache_group in self.runner.attn_groups: - for attn_group in kv_cache_group: - if chosen_layer in attn_group.layer_names: - builder = attn_group.get_metadata_builder() - break - if builder is not None: - break - - assert builder is not None, ( - "Failed to find attention metadata builder for EAGLE layers." 
- ) - return builder + def _get_metadata_builder(self, kv_cache_group_id: int) -> AttentionMetadataBuilder: + return self.runner.attn_groups[kv_cache_group_id][0].get_metadata_builder() def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool: """ @@ -1627,29 +1704,6 @@ def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool: use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True) return use_aux_hidden_state - def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: - """ - Validate that all drafting layers belong to the same KVCacheGroup. - Need this assumption to ensure all drafting layers can use the - same AttentionMetadata. - May extend to multiple AttentionMetadata in the future. - """ - kv_cache_groups: dict[str, int] = {} - for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): - for layer_name in kv_cache_group.layer_names: - kv_cache_groups[layer_name] = id - assert ( - len( - set( - [ - kv_cache_groups[layer_name] - for layer_name in self.attn_layer_names - ] - ) - ) - == 1 - ), "All drafting layers should belong to the same kv cache group" - def _pad_batch_across_dp( self, num_tokens_unpadded: int, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ba1428c42ee4..d16bc0b6cf1c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -111,6 +111,7 @@ AttentionMetadataBuilder, AttentionType, CommonAttentionMetadata, + CommonAttnMetadataByGid, MultipleOf, ) from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder @@ -119,6 +120,7 @@ get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, ) +from vllm.v1.core.kv_cache_utils import get_slot_mappings_by_layer from vllm.v1.core.sched.output import NewRequestData from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import ( @@ -317,13 +319,13 @@ class ExecuteModelState(NamedTuple): scheduler_output: 
"SchedulerOutput" logits: torch.Tensor spec_decode_metadata: SpecDecodeMetadata | None - spec_decode_common_attn_metadata: CommonAttentionMetadata | None hidden_states: torch.Tensor sample_hidden_states: torch.Tensor aux_hidden_states: list[torch.Tensor] | None ec_connector_output: ECConnectorOutput | None cudagraph_stats: CUDAGraphStat | None slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None + cm_by_gid: CommonAttnMetadataByGid class GPUModelRunner( @@ -1689,13 +1691,10 @@ def _build_attention_metadata( num_scheduled_tokens: dict[str, int] | None = None, cascade_attn_prefix_lens: list[list[int]] | None = None, slot_mappings: dict[int, torch.Tensor] | None = None, - ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: - """ - :return: tuple[attn_metadata, spec_decode_common_attn_metadata] - """ + ) -> tuple[PerLayerAttnMetadata, CommonAttnMetadataByGid]: # Attention metadata is not needed for attention free models if len(self.kv_cache_config.kv_cache_groups) == 0: - return {}, None + return {}, dict() num_tokens_padded = num_tokens_padded or num_tokens num_reqs_padded = num_reqs_padded or num_reqs @@ -1793,6 +1792,10 @@ def _get_block_table(kv_cache_gid: int): tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata ] = {} + # CommonAttentionMetadata objects for each KV cache group. + # These are passed to the drafter to support multi-group KV cache in drafters. + cm_by_gid: CommonAttnMetadataByGid = {} + def _build_attn_group_metadata( kv_cache_gid: int, attn_gid: int, @@ -1856,7 +1859,6 @@ def _build_attn_group_metadata( # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. 
- spec_decode_common_attn_metadata = None for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups): cm = copy(cm_base) # shallow copy @@ -1868,16 +1870,10 @@ def _build_attn_group_metadata( num_reqs_padded, for_cudagraph_capture=for_cudagraph_capture, ) - if kv_cache_gid > 0: - cm.block_table_tensor = _get_block_table(kv_cache_gid) - cm.slot_mapping = slot_mappings[kv_cache_gid] - - if self.speculative_config and spec_decode_common_attn_metadata is None: - if isinstance(self.drafter, EagleProposer): - if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: - spec_decode_common_attn_metadata = cm - else: - spec_decode_common_attn_metadata = cm + cm.block_table_tensor = _get_block_table(kv_cache_gid) + cm.slot_mapping = slot_mappings[kv_cache_gid] + + cm_by_gid[kv_cache_gid] = cm for attn_gid in range(len(self.attn_groups[kv_cache_gid])): if ubatch_slices is not None: @@ -1907,17 +1903,17 @@ def _build_attn_group_metadata( for _metadata in attn_metadata.values(): _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] - if spec_decode_common_attn_metadata is not None and ( + if cm_by_gid and ( num_reqs != num_reqs_padded or num_tokens != num_tokens_padded ): # Currently the drafter still only uses piecewise cudagraphs (and modifies # the attention metadata in directly), and therefore does not want to use # padded attention metadata. 
- spec_decode_common_attn_metadata = ( - spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) - ) + cm_by_gid = { + gid: cm.unpadded(num_tokens, num_reqs) for gid, cm in cm_by_gid.items() + } - return attn_metadata, spec_decode_common_attn_metadata + return attn_metadata, cm_by_gid def _compute_cascade_attn_prefix_lens( self, @@ -3122,8 +3118,8 @@ def _determine_batch_execution_and_padding( has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) - dispatch_cudagraph = ( - lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch( + dispatch_cudagraph = lambda num_tokens, disable_full: ( + self.cudagraph_dispatcher.dispatch( num_tokens=num_tokens, has_lora=has_lora, uniform_decode=uniform_decode, @@ -3262,6 +3258,7 @@ def _get_slot_mappings( A tuple of: - slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata - slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext + Or None if KV cache config is not available. 
""" if not ( hasattr(self, "kv_cache_config") @@ -3295,22 +3292,9 @@ def _get_slot_mapping(kv_cache_gid: int): gid: _get_slot_mapping(gid) for gid, _ in enumerate(self.kv_cache_config.kv_cache_groups) } - - slot_mappings_by_layer: dict[str, torch.Tensor] = {} - for gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups): - slot_mapping = slot_mappings_by_gid[gid] - for layer_name in kv_cache_group.layer_names: - slot_mappings_by_layer[layer_name] = slot_mapping - - if ubatch_slices is not None: - result: list[dict[str, torch.Tensor]] = [] - for ubatch in ubatch_slices: - sliced_mappings: dict[str, torch.Tensor] = {} - for layer_name, slot_mapping in slot_mappings_by_layer.items(): - sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice] - result.append(sliced_mappings) - return slot_mappings_by_gid, result - + slot_mappings_by_layer = get_slot_mappings_by_layer( + self.kv_cache_config, slot_mappings_by_gid, ubatch_slices + ) return slot_mappings_by_gid, slot_mappings_by_layer @torch.inference_mode() @@ -3481,20 +3465,18 @@ def execute_model( ubatch_slices=ubatch_slices_padded, ) - attn_metadata, spec_decode_common_attn_metadata = ( - self._build_attention_metadata( - num_tokens=num_tokens_unpadded, - num_tokens_padded=num_tokens_padded if pad_attn else None, - num_reqs=num_reqs, - num_reqs_padded=num_reqs_padded if pad_attn else None, - max_query_len=max_num_scheduled_tokens, - ubatch_slices=ubatch_slices_attn, - logits_indices=logits_indices, - use_spec_decode=use_spec_decode, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - cascade_attn_prefix_lens=cascade_attn_prefix_lens, - slot_mappings=slot_mappings_by_group, - ) + attn_metadata, cm_by_gid = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + 
logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, + slot_mappings=slot_mappings_by_group, ) ( @@ -3611,13 +3593,13 @@ def execute_model( scheduler_output, logits, spec_decode_metadata, - spec_decode_common_attn_metadata, hidden_states, sample_hidden_states, aux_hidden_states, ec_connector_output, cudagraph_stats, slot_mappings, + cm_by_gid, ) self.kv_connector_output = kv_connector_output return None @@ -3650,13 +3632,13 @@ def sample_tokens( scheduler_output, logits, spec_decode_metadata, - spec_decode_common_attn_metadata, hidden_states, sample_hidden_states, aux_hidden_states, ec_connector_output, cudagraph_stats, slot_mappings, + cm_by_gid, ) = self.execute_model_state # Clear ephemeral state. self.execute_model_state = None @@ -3688,7 +3670,7 @@ def sample_tokens( self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): - assert spec_decode_common_attn_metadata is not None + assert cm_by_gid is not None with record_function_or_nullcontext("gpu_model_runner: draft"): self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, @@ -3698,16 +3680,19 @@ def propose_draft_token_ids(sampled_token_ids): sample_hidden_states, aux_hidden_states, spec_decode_metadata, - spec_decode_common_attn_metadata, slot_mappings, + cm_by_gid, ) self._copy_draft_token_ids_to_cpu(scheduler_output) spec_config = self.speculative_config propose_drafts_after_bookkeeping = False if spec_config is not None: - input_fits_in_drafter = spec_decode_common_attn_metadata is not None and ( - spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens + # Use any CAM from cm_by_gid for the max_seq_len check (it's a global + # batch property shared across all KV cache groups). 
+ any_cm = next(iter(cm_by_gid.values()), None) if cm_by_gid else None + input_fits_in_drafter = any_cm is not None and ( + any_cm.max_seq_len + self.num_spec_tokens <= self.effective_drafter_max_model_len ) use_gpu_toks = ( @@ -3721,10 +3706,9 @@ def propose_draft_token_ids(sampled_token_ids): if input_fits_in_drafter: propose_draft_token_ids(sampled_token_ids) elif self.valid_sampled_token_count_event is not None: - assert spec_decode_common_attn_metadata is not None next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - spec_decode_common_attn_metadata, + cm_by_gid, sampled_token_ids, self.requests, self.input_batch, @@ -3946,8 +3930,8 @@ def propose_draft_token_ids( sample_hidden_states: torch.Tensor, aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, - common_attn_metadata: CommonAttentionMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, + cm_by_gid: CommonAttnMetadataByGid, ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config @@ -4023,7 +4007,7 @@ def propose_draft_token_ids( ) next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - common_attn_metadata, + cm_by_gid, sampled_token_ids, self.requests, self.input_batch, @@ -4050,8 +4034,8 @@ def propose_draft_token_ids( else: if spec_config.disable_padded_drafter_batch: token_indices_to_sample = None - common_attn_metadata, token_indices = self.drafter.prepare_inputs( - common_attn_metadata, + cm_by_gid, token_indices = self.drafter.prepare_inputs( + cm_by_gid, sampled_token_ids, spec_decode_metadata.num_draft_tokens, ) @@ -4066,15 +4050,18 @@ def propose_draft_token_ids( target_hidden_states = hidden_states[token_indices] else: ( - common_attn_metadata, + cm_by_gid, token_indices_to_sample, num_rejected_tokens_gpu, ) = self.drafter.prepare_inputs_padded( - 
common_attn_metadata, + cm_by_gid, spec_decode_metadata, valid_sampled_tokens_count, ) - total_num_tokens = common_attn_metadata.num_actual_tokens + + # Get any CAM to access the shared num_actual_tokens + any_cm = next(iter(cm_by_gid.values())) + total_num_tokens = any_cm.num_actual_tokens # When padding the batch, token_indices is just a range target_token_ids = self.input_ids.gpu[:total_num_tokens] target_positions = self._get_positions(total_num_tokens) @@ -4101,10 +4088,10 @@ def propose_draft_token_ids( next_token_ids=next_token_ids, token_indices_to_sample=token_indices_to_sample, sampling_metadata=sampling_metadata, - common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs, num_rejected_tokens_gpu=num_rejected_tokens_gpu, slot_mappings=slot_mappings, + cm_by_gid=cm_by_gid, ) return draft_token_ids @@ -6072,15 +6059,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config, kernel_block_sizes ) - if self.speculative_config and ( - self.speculative_config.use_eagle() - or self.speculative_config.uses_draft_model() - ): - assert isinstance(self.drafter, EagleProposer | DraftModelProposer) - # validate all draft model layers belong to the same kv cache - # group - self.drafter.validate_same_kv_cache_group(kv_cache_config) - if has_kv_transfer_group(): kv_transfer_group = get_kv_transfer_group() if self.cross_layers_kv_cache is not None: