diff --git a/examples/offline_inference/extract_hidden_states.py b/examples/offline_inference/extract_hidden_states.py new file mode 100644 index 000000000000..61299101cb47 --- /dev/null +++ b/examples/offline_inference/extract_hidden_states.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import tempfile + +from safetensors import safe_open + +from vllm import LLM, SamplingParams + +# Example: Using the custom "extract_hidden_states" speculator method and +# ExampleHiddenStatesConnector to extract and save hidden states from vllm + +with tempfile.TemporaryDirectory() as tmpdirname: + llm = LLM( + model="Qwen/Qwen3-8B", # Your target model + speculative_config={ + "method": "extract_hidden_states", + "num_speculative_tokens": 1, + "draft_model_config": { + "hf_config": { + "eagle_aux_hidden_state_layer_ids": [ # Target model layer indices + 1, + 2, + 3, + 4, + ], + } + }, + }, + kv_transfer_config={ + "kv_connector": "ExampleHiddenStatesConnector", + "kv_role": "kv_producer", + "kv_connector_extra_config": { + "shared_storage_path": tmpdirname, + }, + }, + ) + + prompts = ["Generate a sentence with hidden states", "Write a python function"] + sampling_params = SamplingParams(max_tokens=1) + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + print("\nPrompt:", output.prompt) + print("Prompt token ids:", output.prompt_token_ids) + + hidden_states_path = output.kv_transfer_params.get("hidden_states_path") + assert hidden_states_path is not None + print("Prompt hidden states path:", hidden_states_path) + + with safe_open(hidden_states_path, "pt") as f: + token_ids = f.get_tensor("token_ids") + hidden_states = f.get_tensor("hidden_states") + + print("Extracted token ids:", token_ids) # Matches prompt token ids + print( + "Extracted hidden states shape:", hidden_states.shape + ) # [num_hidden_layers, prompt len, hidden size] + print("Extracted hidden states:", 
hidden_states) diff --git a/tests/models/registry.py b/tests/models/registry.py index c8e47ad502f9..6dd5f3d08254 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -108,7 +108,7 @@ class _HfExamplesInfo: use_original_num_layers: bool = False """ - If True, use the original number of layers from the model config + If True, use the original number of layers from the model config instead of minimal layers for testing. """ @@ -1160,6 +1160,10 @@ def check_available_online( speculative_model="LGAI-EXAONE/K-EXAONE-236B-A23B", min_transformers_version="5.1.0", ), + "ExtractHiddenStatesModel": _HfExamplesInfo( + "Qwen/Qwen3-8B", + speculative_method="extract_hidden_states", + ), "Glm4MoeMTPModel": _HfExamplesInfo( "zai-org/GLM-4.5", speculative_model="zai-org/GLM-4.5", diff --git a/tests/v1/kv_connector/extract_hidden_states_integration/__init__.py b/tests/v1/kv_connector/extract_hidden_states_integration/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/kv_connector/extract_hidden_states_integration/predictable_llama.py b/tests/v1/kv_connector/extract_hidden_states_integration/predictable_llama.py new file mode 100644 index 000000000000..5b130e9ac679 --- /dev/null +++ b/tests/v1/kv_connector/extract_hidden_states_integration/predictable_llama.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Predictable dummy model for testing extract_hidden_states. + +Subclasses LlamaForCausalLM but overrides the model to produce deterministic +hidden states: layer i outputs values equal to (i). 
+""" + +from collections.abc import Iterable + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.sequence import IntermediateTensors + + +class PredictableLlamaModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.aux_hidden_state_layers = tuple[int, ...]() + + # Create minimal embed_tokens for embedding + from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, + ) + + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + ) + + # Required for pipeline parallelism + from vllm.model_executor.models.utils import ( + make_empty_intermediate_tensors_factory, + ) + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """Embed input IDs.""" + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + **extra_layer_kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """Forward pass that produces predictable outputs. 
+ + Returns: + If aux_hidden_state_layers is set: (hidden_states, aux_hidden_states) + Otherwise: hidden_states + """ + # Determine sequence length + if inputs_embeds is not None: + seq_len = inputs_embeds.shape[0] + device = inputs_embeds.device + elif input_ids is not None: + seq_len = input_ids.shape[0] if input_ids.ndim == 1 else input_ids.shape[-1] + device = input_ids.device + else: + raise ValueError("Either input_ids or inputs_embeds must be provided") + + # Final hidden states (last layer value) + hidden_states = torch.full( + (seq_len, self.config.hidden_size), + fill_value=float(self.config.num_hidden_layers), + device=device, + dtype=torch.bfloat16, + ) + + # Check if we need auxiliary hidden states + if len(self.aux_hidden_state_layers) > 0: + aux_hidden_states = [] + for layer_idx in self.aux_hidden_state_layers: + # Fill with (layer_idx) for predictability + layer_hidden = torch.full( + (seq_len, self.config.hidden_size), + fill_value=float(layer_idx), + device=device, + dtype=torch.bfloat16, + ) + aux_hidden_states.append(layer_hidden) + + return hidden_states, aux_hidden_states + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Skip weight loading.""" + return set() + + +class PredictableLlamaForCausalLM(LlamaForCausalLM): + """Predictable Llama model for testing. + + Overrides _init_model to use PredictableLlamaModel instead of LlamaModel. 
+ """ + + def _init_model( + self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] | None = None, + ): + """Initialize with predictable model.""" + return PredictableLlamaModel(vllm_config=vllm_config, prefix=prefix) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Skip weight loading for dummy model.""" + return set() diff --git a/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py b/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py new file mode 100644 index 000000000000..6a8c64152fec --- /dev/null +++ b/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import gc +import os + +import pytest +import torch +from safetensors import safe_open + +from vllm import LLM, ModelRegistry, SamplingParams + + +def get_and_check_output(output, expected_shape): + assert output.kv_transfer_params is not None + hidden_states_path = output.kv_transfer_params.get("hidden_states_path") + assert hidden_states_path is not None + assert os.path.exists(hidden_states_path) + + # Load and verify the saved tensors + with safe_open(hidden_states_path, "pt") as f: + # Check that token_ids and hidden_states are present + tensor_names = f.keys() + assert "token_ids" in tensor_names + assert "hidden_states" in tensor_names + + token_ids = f.get_tensor("token_ids") + hidden_states = f.get_tensor("hidden_states") + + prompt_token_ids = output.prompt_token_ids + assert torch.equal(token_ids, torch.tensor(prompt_token_ids)) + + assert hidden_states.shape == expected_shape + + # Verify hidden_states are not all zeros (i.e., they were actually computed) + assert not torch.allclose(hidden_states, torch.zeros_like(hidden_states)) + + return token_ids, hidden_states + + +@pytest.fixture(scope="module") +def 
predictable_llama_config_path(tmp_path_factory): + """Create a minimal LlamaConfig for PredictableLlamaForCausalLM.""" + from transformers import LlamaConfig, LlamaTokenizerFast + + config_dir = tmp_path_factory.mktemp("predictable_llama") + + # Create a minimal Llama config with small dimensions + config = LlamaConfig( + vocab_size=1000, + hidden_size=256, + intermediate_size=512, + num_hidden_layers=24, # Enough layers to test various layer_ids + num_attention_heads=4, + num_key_value_heads=4, + max_position_embeddings=128, + architectures=["PredictableLlamaForCausalLM"], + ) + + # Save config + config.save_pretrained(config_dir) + + # Create a simple tokenizer + tokenizer = LlamaTokenizerFast.from_pretrained( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + cache_dir=os.path.expanduser("~/.cache/huggingface"), + ) + tokenizer.save_pretrained(config_dir) + + return str(config_dir) + + +@pytest.fixture(scope="module", autouse=True) +def register_predictable_model(): + """Register the PredictableLlamaForCausalLM model.""" + from .predictable_llama import PredictableLlamaForCausalLM + + if "PredictableLlamaForCausalLM" not in ModelRegistry.get_supported_archs(): + ModelRegistry.register_model( + "PredictableLlamaForCausalLM", PredictableLlamaForCausalLM + ) + yield + + +def test_extract_hidden_states_with_predictable_dummy_model( + predictable_llama_config_path, tmp_path +): + """Comprehensive test using a predictable dummy model with synthetic weights. + + The PredictableLlamaForCausalLM outputs deterministic hidden states where + each layer produces values equal to (layer_index). This test verifies: + 1. Hidden states are correctly extracted from requested layers + 2. Values match the expected predictable pattern + 3. Layer ordering is preserved correctly (non-sequential layer IDs) + 4. 
Multiple prompts of different lengths produce consistent layer values + """ + # Test with non-sequential layer ordering to verify correct association + layer_ids = [5, 2, 10] + num_layers = len(layer_ids) + + llm = LLM( + model=predictable_llama_config_path, + speculative_config={ + "method": "extract_hidden_states", + "num_speculative_tokens": 1, + "draft_model_config": { + "hf_config": {"eagle_aux_hidden_state_layer_ids": layer_ids} + }, + }, + kv_transfer_config={ + "kv_connector": "ExampleHiddenStatesConnector", + "kv_role": "kv_producer", + "kv_connector_extra_config": {"shared_storage_path": tmp_path}, + }, + max_model_len=128, + enforce_eager=True, + trust_remote_code=True, + load_format="dummy", # Don't try to load real weights + ) + + # Test with multiple prompts of different lengths + prompts = [ + "Short", + "Medium length", + "Much longer prompt with many tokens", + "Much longer prompt with many tokens", # repeated prompt + ] + sampling_params = SamplingParams(max_tokens=1, temperature=0.0) + hidden_size = llm.llm_engine.model_config.get_hidden_size() + outputs = llm.generate(prompts, sampling_params) + del llm + gc.collect() + + assert len(outputs) == len(prompts) + + for output in outputs: + # hidden_states shape is [prompt_len, num_hidden_layers, hidden_size] + expected_shape = ( + len(output.prompt_token_ids), + num_layers, + hidden_size, + ) + _token_ids, hidden_states = get_and_check_output(output, expected_shape) + + for idx, layer_id in enumerate(layer_ids): + layer_hidden = hidden_states[:, idx, :] + assert torch.allclose( + layer_hidden, + torch.full_like(layer_hidden, layer_id), + atol=1e-5, + ), ( + f"Layer {layer_id} at position {idx} should output {float(layer_id)}, " + f"but got mean={layer_hidden.mean():.3f}, " + f"min={layer_hidden.min():.3f}, max={layer_hidden.max():.3f}" + ) diff --git a/tests/v1/spec_decode/test_extract_hidden_states.py b/tests/v1/spec_decode/test_extract_hidden_states.py new file mode 100644 index 
000000000000..af911e91d4b3 --- /dev/null +++ b/tests/v1/spec_decode/test_extract_hidden_states.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest import mock + +import pytest +import torch + +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, +) +from vllm.config import ( + AttentionConfig, + CacheConfig, + DeviceConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.config.load import LoadConfig +from vllm.platforms import current_platform +from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + +model_dir = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + + +def _create_proposer( + num_speculative_tokens: int = 1, + layer_ids: list[int] | None = None, +) -> ExtractHiddenStatesProposer: + """Create an ExtractHiddenStatesProposer for testing.""" + if layer_ids is None: + layer_ids = [1, 2, 3, 4] + + model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) + + speculative_config = SpeculativeConfig( + target_model_config=model_config, + target_parallel_config=ParallelConfig(), + method="extract_hidden_states", + num_speculative_tokens=num_speculative_tokens, + draft_model_config={ + "hf_config": { + "eagle_aux_hidden_state_layer_ids": layer_ids, + } + }, + ) + + device = current_platform.device_type + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + speculative_config=speculative_config, + device_config=DeviceConfig(device=device), + parallel_config=ParallelConfig(), + load_config=LoadConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), + attention_config=AttentionConfig(), + ) + + return ExtractHiddenStatesProposer(vllm_config=vllm_config, 
device=device) + + +def test_proposer_initialization(): + """Test that the proposer initializes correctly with the right parameters.""" + layer_ids = [1, 2, 3, 4] + proposer = _create_proposer(num_speculative_tokens=1, layer_ids=layer_ids) + + assert proposer.num_hidden_states == len(layer_ids) + assert proposer.vllm_config.speculative_config is not None + assert proposer.vllm_config.speculative_config.num_speculative_tokens == 1 + + # Verify the hidden states buffer is correctly shaped + expected_shape = ( + proposer.max_num_tokens, + len(layer_ids), + proposer.hidden_size, + ) + assert proposer.hidden_states.shape == expected_shape + + +def test_proposer_initialization_missing_layer_ids(): + """Test that initialization fails when layer_ids are not provided.""" + model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) + + speculative_config = SpeculativeConfig( + target_model_config=model_config, + target_parallel_config=ParallelConfig(), + method="extract_hidden_states", + num_speculative_tokens=1, + draft_model_config={ + "hf_config": {} # Missing eagle_aux_hidden_state_layer_ids + }, + ) + + device = current_platform.device_type + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + speculative_config=speculative_config, + device_config=DeviceConfig(device=device), + parallel_config=ParallelConfig(), + load_config=LoadConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), + attention_config=AttentionConfig(), + ) + + with pytest.raises( + ValueError, match="eagle_aux_hidden_state_layer_ids must be set" + ): + ExtractHiddenStatesProposer(vllm_config=vllm_config, device=device) + + +def test_prepare_next_token_ids_padded(): + """ + Test for prepare_next_token_ids_padded with extract_hidden_states. + + Since num_speculative_tokens == 1, sampled_token_ids has shape (batch_size, 1). 
+ For each request we either use the sampled token (if valid and not discarded) + or a backup token from the request state. + """ + device = torch.device(current_platform.device_type) + + num_requests = 4 + batch_spec = BatchSpec( + seq_lens=[5] * num_requests, + query_lens=[5] * num_requests, + ) + + req_ids = [f"req_{i + 1}" for i in range(num_requests)] + mock_input_batch = mock.MagicMock(spec=InputBatch) + mock_input_batch.req_ids = req_ids + mock_input_batch.num_reqs = num_requests + mock_input_batch.vocab_size = 100 + + mock_requests = {} + for req_id in req_ids: + mock_request = mock.MagicMock(spec=CachedRequestState) + # Each request will have a backup next token id of 10, 20, 30, 40 + mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10 + mock_requests[req_id] = mock_request + + # explicitly discard the last request + discarded_req_mask = torch.tensor( + [False, False, False, True], dtype=torch.bool, device=device + ) + + # With num_speculative_tokens=1, sampled_token_ids has shape [batch_size, 1] + sampled_token_ids = torch.tensor( + [ + [1], # valid, use 1 + [4], # valid, use 4 + [-1], # invalid, use backup token "30" + [2], # explicitly discarded, use backup token "40" + ], + dtype=torch.int32, + device=device, + ) + + expected_next_token_ids_cpu = [1, 4, 30, 40] + expected_next_token_ids_tensor = torch.tensor( + expected_next_token_ids_cpu, dtype=torch.int32, device=device + ) + + proposer = _create_proposer(num_speculative_tokens=1) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + # valid_sampled_tokens_count tracks if token is valid (not -1 and in vocab range) + # It doesn't depend on whether the request is discarded + expected_valid_sampled_tokens_count = torch.tensor( + [1, 1, 0, 1], dtype=torch.int32, device=device + ) + + next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + mock_requests, 
+ mock_input_batch, + discarded_req_mask, + ) + + assert torch.equal(next_token_ids, expected_next_token_ids_tensor) + assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count) + + +def test_propose(): + """ + Test the propose() method of ExtractHiddenStatesProposer. + + This should: + 1. Accept target hidden states and sampled token IDs + 2. Return the sampled tokens as "draft" tokens (shape [batch_size, 1]) + 3. Cache the hidden states in the model's KV cache + """ + device = torch.device(current_platform.device_type) + + # Setup test parameters + batch_size = 2 + num_tokens = 5 + num_hidden_layers = 4 + + proposer = _create_proposer( + num_speculative_tokens=1, layer_ids=list(range(num_hidden_layers)) + ) + hidden_size = proposer.hidden_size + + # Create mock model + model_mock = mock.MagicMock() + proposer.model = model_mock + + # Mock attention layer names + proposer.attn_layer_names = ["cache_only_layers.28"] + + # Mock attention metadata builder + mock_attn_metadata = mock.MagicMock() + mock_attn_metadata_builder = mock.MagicMock() + mock_attn_metadata_builder.build_for_drafting.return_value = mock_attn_metadata + proposer.attn_metadata_builder = mock_attn_metadata_builder + + # Create input tensors + batch_spec = BatchSpec( + seq_lens=[3, 2], + query_lens=[3, 2], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + # Create target hidden states: list of tensors, one per layer + # Each tensor has shape [num_tokens, hidden_size] + target_hidden_states = [ + torch.randn(num_tokens, hidden_size, dtype=proposer.dtype, device=device) + for _ in range(num_hidden_layers) + ] + + # Sampled token IDs from target model + sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device) + + # Mock scheduler output + mock_scheduler_output = mock.MagicMock() + + # Call propose + with mock.patch( + "vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group" + ) as 
mock_has_kv: + mock_has_kv.return_value = False + + draft_tokens, kv_connector_output = proposer.propose( + sampled_token_ids=sampled_token_ids, + target_hidden_states=target_hidden_states, + common_attn_metadata=common_attn_metadata, + scheduler_output=mock_scheduler_output, + slot_mappings=None, + ) + + # Verify draft tokens match sampled tokens + # Shape should be [batch_size, 1] for num_speculative_tokens=1 + assert draft_tokens.shape == (batch_size, 1) + assert torch.equal(draft_tokens[:, 0], sampled_token_ids) + + # Verify the model was called + model_mock.assert_called_once() + + # Verify hidden states were copied to the buffer The stacked hidden states + # should have shape [num_tokens, num_hidden_layers, hidden_size] + expected_stacked = torch.stack(target_hidden_states, dim=1) + assert torch.allclose( + proposer.hidden_states[:num_tokens], expected_stacked, atol=1e-6 + ) + + +@pytest.mark.parametrize("num_hidden_layers", [1, 4, 8]) +def test_propose_different_layer_counts(num_hidden_layers): + """Test that propose works correctly with different numbers of hidden layers.""" + device = torch.device(current_platform.device_type) + + batch_size = 2 + num_tokens = 5 + + proposer = _create_proposer( + num_speculative_tokens=1, layer_ids=list(range(num_hidden_layers)) + ) + hidden_size = proposer.hidden_size + + # Setup mocks + model_mock = mock.MagicMock() + proposer.model = model_mock + proposer.attn_layer_names = ["cache_only_layers.28"] + + mock_attn_metadata_builder = mock.MagicMock() + mock_attn_metadata_builder.build_for_drafting.return_value = mock.MagicMock() + proposer.attn_metadata_builder = mock_attn_metadata_builder + + batch_spec = BatchSpec( + seq_lens=[3, 2], + query_lens=[3, 2], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + # Create target hidden states + target_hidden_states = [ + torch.randn(num_tokens, hidden_size, dtype=proposer.dtype, device=device) + for _ in 
range(num_hidden_layers) + ] + + sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device) + mock_scheduler_output = mock.MagicMock() + + with mock.patch( + "vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group" + ) as mock_has_kv: + mock_has_kv.return_value = False + + draft_tokens, _ = proposer.propose( + sampled_token_ids=sampled_token_ids, + target_hidden_states=target_hidden_states, + common_attn_metadata=common_attn_metadata, + scheduler_output=mock_scheduler_output, + slot_mappings=None, + ) + + assert draft_tokens.shape == (batch_size, 1) + assert torch.equal(draft_tokens[:, 0], sampled_token_ids) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index c2bced7842d3..a950ba531ad2 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast +import copy from typing import TYPE_CHECKING, Any, Literal, get_args from pydantic import Field, SkipValidation, model_validator @@ -45,7 +46,7 @@ "pangu_ultra_moe_mtp", "step3p5_mtp", ] -EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] +EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes] SpeculativeMethod = Literal[ "ngram", "medusa", @@ -181,9 +182,22 @@ def compute_hash(self) -> str: the final hidden states. """ factors: list[Any] = [] - # Eagle3 affects the computation graph because it returns intermediate - # hidden states in addition to the final hidden state. - factors.append(self.method == "eagle3") + # Eagle3 and extract_hidden_states affect the computation graph because + # they return intermediate hidden states in addition to the final hidden state. 
+ uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states") + factors.append(uses_aux_hidden_states) + + # The specific layers used also affect the computation graph + if uses_aux_hidden_states and self.draft_model_config is not None: + layer_ids = getattr( + self.draft_model_config.hf_config, + "eagle_aux_hidden_state_layer_ids", + None, + ) + if layer_ids is not None: + # Convert to tuple to make it hashable + factors.append(tuple(layer_ids)) + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @@ -352,6 +366,8 @@ def __post_init__(self): self.model = "ngram" elif self.method == "suffix": self.model = "suffix" + elif self.method == "extract_hidden_states": + self.model = "extract_hidden_states" else: raise ValueError( "num_speculative_tokens was provided but without speculative model." @@ -394,6 +410,34 @@ def __post_init__(self): self.draft_parallel_config = self.target_parallel_config elif self.method == "suffix": self._validate_suffix_decoding() + elif self.method == "extract_hidden_states": + from vllm.transformers_utils.configs.extract_hidden_states import ( + ExtractHiddenStatesConfig, + ) + + # ExtractHiddenStatesModel is instantiated manually in load_model() + # We just need to store the target model config for KV cache shape info + self.model = "extract_hidden_states" + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if hasattr(self.draft_model_config, "hf_config"): + hf_config = self.draft_model_config.hf_config.to_dict() + elif ( + isinstance(self.draft_model_config, dict) + and "hf_config" in self.draft_model_config + ): + hf_config = self.draft_model_config["hf_config"] + else: + hf_config = {} + + self.draft_model_config = copy.copy(self.target_model_config) + self.draft_model_config.hf_config = ExtractHiddenStatesConfig( + self.draft_model_config.hf_config, **hf_config + ) + self.update_arch_() + self.draft_parallel_config = self.target_parallel_config + else: 
self.prompt_lookup_max = 0 self.prompt_lookup_min = 0 @@ -478,23 +522,8 @@ def __post_init__(self): method=self.method, model_type="eagle", ) - # EAGLEConfig primarily updates architectures, so update - # all architectures-related fields in draft_model_config self.draft_model_config.hf_config = eagle_config - self.draft_model_config.hf_text_config = get_hf_text_config( - self.draft_model_config.hf_config - ) - self.draft_model_config.model_arch_config = ( - self.draft_model_config.get_model_arch_config() - ) - model_info, arch = ( - self.draft_model_config.registry.inspect_model_cls( - self.draft_model_config.architectures, - self.draft_model_config, - ) - ) - self.draft_model_config._model_info = model_info - self.draft_model_config._architecture = arch + self.update_arch_() if self.num_speculative_tokens is not None and hasattr( self.draft_model_config.hf_config, "num_lookahead_tokens" @@ -671,6 +700,24 @@ def _verify_and_get_draft_tp( ) return speculative_draft_tensor_parallel_size + def update_arch_(self): + """ + EagleConfig and ExtractHiddenStatesConfig update architectures, so update all + architectures-related fields in self.draft_model_config + """ + self.draft_model_config.hf_text_config = get_hf_text_config( + self.draft_model_config.hf_config + ) + self.draft_model_config.model_arch_config = ( + self.draft_model_config.get_model_arch_config() + ) + model_info, arch = self.draft_model_config.registry.inspect_model_cls( + self.draft_model_config.architectures, + self.draft_model_config, + ) + self.draft_model_config._model_info = model_info + self.draft_model_config._architecture = arch + @staticmethod def create_draft_parallel_config( target_parallel_config: ParallelConfig, @@ -718,7 +765,7 @@ def _verify_args(self) -> Self: self.draft_parallel_config ) - eagle3_target_supported = [ + aux_hidden_states_supported = [ "llama", "qwen", "minicpm", @@ -729,16 +776,16 @@ def _verify_args(self) -> Self: "nemotron_h", ] if ( - self.method == "eagle3" + 
self.method in ("eagle3", "extract_hidden_states") and self.target_model_config and not any( supported_model in self.target_model_config.hf_text_config.model_type - for supported_model in eagle3_target_supported + for supported_model in aux_hidden_states_supported ) ): raise ValueError( - f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 - f"Got {self.target_model_config.hf_text_config.model_type=}" + f"{self.method} is only supported for {aux_hidden_states_supported}" + f" models. Got {self.target_model_config.hf_text_config.model_type=}" ) self.verify_equal_vocab_size_if_draft_model() return self @@ -782,8 +829,15 @@ def use_eagle(self) -> bool: def uses_draft_model(self) -> bool: return self.method == "draft_model" + def uses_extract_hidden_states(self) -> bool: + return self.method == "extract_hidden_states" + def __repr__(self) -> str: method = self.method - model = None if method in ("ngram", "suffix") else self.draft_model_config.model + model = ( + None + if method in ("ngram", "suffix", "extract_hidden_states") + else self.draft_model_config.model + ) num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 096ed4418546..21ec7a36e984 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -209,6 +209,10 @@ def get_number_of_workers(self) -> int: def clear_events(self) -> None: raise NotImplementedError + def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents": + self.add_events(other.get_all_events()) + return self + class EventPublisher(ABC): """Lightweight publisher for EventBatch batches with data parallelism diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 1ceac39711b2..d5a40fc639b4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -149,6 +149,12 @@ def get_connector_class( "ExampleConnector", ) +KVConnectorFactory.register_connector( + "ExampleHiddenStatesConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector", + "ExampleHiddenStatesConnector", +) + KVConnectorFactory.register_connector( "P2pNcclConnector", "vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py new file mode 100644 index 000000000000..945f8d9fd182 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Optional + +import safetensors +import torch + +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.logger import init_logger +from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +def extract_from_kv_cache( + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + num_tokens: int, +) -> torch.Tensor: + """Extract data from KV cache + Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size) + """ + + padded_kv = kv_cache.flatten(0, 1)[slot_mapping] + # shape: [len(slot_mapping), num_heads, head_size] + return padded_kv[:num_tokens] # 
shape: [num_tokens, num_heads, head_size] + + +@dataclass +class ReqMeta: + # Request ID + req_id: str + # Request filename + filename: str + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + # Whether this request is a new request or partially computed already + new_req: bool + + @staticmethod + def make_meta( + req_id: str, + filename: str, + token_ids: list[int], + block_ids: list[int], + block_size: int, + new_req: bool, + ) -> "ReqMeta": + token_ids_tensor = torch.tensor(token_ids) + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = ( + block_offsets.reshape((1, block_size)) + + block_ids_tensor.reshape((num_blocks, 1)) * block_size + ) + slot_mapping = slot_mapping.flatten() + return ReqMeta( + req_id=req_id, + filename=filename, + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + new_req=new_req, + ) + + +@dataclass +class ExampleHiddenStatesConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] = field(default_factory=list) + + def add_request( + self, + req_id: str, + filename: str, + token_ids: list[int], + block_ids: list[int], + block_size: int, + new_req: bool = True, + ) -> None: + self.requests.append( + ReqMeta.make_meta( + req_id, filename, token_ids, block_ids, block_size, new_req + ) + ) + + +class ExampleHiddenStatesConnector(KVConnectorBase_V1): + """ + Simple debug implementation of a HiddenStatesConnector. + + Simply extracts the hidden states from the kv cache and stores them to disk. + Must be used in conjunction with the `extract_hidden_states` spec decoding method. + """ + + @property + def prefer_cross_layer_blocks(self) -> bool: + """ + Indicates whether this connector prefers KV blocks that hold KV data for all + layers, which can speed up KV data transfers. Defaults to False. 
+ """ + # Must be False so that drafter kv cache isn't merged with verifier's + return False + + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) + self._block_size = vllm_config.cache_config.block_size + self._storage_path = self._kv_transfer_config.get_from_extra_config( + "shared_storage_path", "/tmp" + ) + self.cache_layers: list[str] = [] # set by self.register_kv_caches + logger.info(self._kv_transfer_config) + logger.info("Shared storage path is %s", self._storage_path) + + assert self._vllm_config.speculative_config is not None, ( + "ExampleHiddenStatesConnector only works when using " + "'extract_hidden_states' speculative method" + ) + spec_config = self._vllm_config.speculative_config.draft_model_config.hf_config + self.num_hidden_states = len( + getattr(spec_config, "eagle_aux_hidden_state_layer_ids", []) + ) + + self._request_filenames: dict[str, str] = {} + self._active_requests: dict[str, NewRequestData] = {} + self._req_blocks: dict[str, list[int]] = {} + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, *args, **kwargs: Any) -> None: + pass # Empty implementation of abstract method + + def wait_for_layer_load(self, layer_name: str) -> None: + pass # Empty implementation of abstract method + + def wait_for_save(self): + pass # Empty implementation of abstract method + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + from vllm.model_executor.models.extract_hidden_states import ( + CacheOnlyAttentionLayer, + ) + + # Filter layers to only include CacheOnlyAttentionLayers + layers = get_layers_from_vllm_config( + self._vllm_config, CacheOnlyAttentionLayer, list(kv_caches.keys()) + ) + self.cache_layers = list(layers.keys()) + assert len(self.cache_layers) == 1, ( + f"Expected 1 
CacheOnlyAttentionLayer, got {len(self.cache_layers)}" + ) + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs: Any, + ) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + if layer_name not in self.cache_layers: + return + + from vllm.model_executor.models.extract_hidden_states import ( + CacheOnlyAttentionMetadata, + ) + + assert isinstance(attn_metadata, CacheOnlyAttentionMetadata), ( + "ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend" + ) + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata) + + os.makedirs(self._storage_path, exist_ok=True) + for request in connector_metadata.requests: + hidden_states = extract_from_kv_cache( + kv_layer, request.slot_mapping, request.token_ids.shape[0] + ) + tensors = { + "hidden_states": hidden_states.detach().cpu(), + "token_ids": request.token_ids.detach().cpu(), + } + safetensors.torch.save_file(tensors, request.filename) + + # ============================== + # Scheduler-side methods + # ============================== + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. 
+ """ + # This connector is store-only, so we don't need to load any tokens + return 0, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + # Usually used to handle allocation of new blocks for requests that are loading + # tokens from connector's external kv cache. We never load from external cache + # so this is a no-op. + assert num_external_tokens == 0, "This connector is store-only" + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + meta = ExampleHiddenStatesConnectorMetadata() + for new_req in scheduler_output.scheduled_new_reqs: + token_ids = new_req.prompt_token_ids or [] + filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors") + meta.add_request( + new_req.req_id, + filename=filename, + token_ids=token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + ) + self._request_filenames[new_req.req_id] = filename + self._active_requests[new_req.req_id] = new_req + self._req_blocks[new_req.req_id] = list(new_req.block_ids[0]) + + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + if req_id not in self._active_requests: + continue + + new_block_ids = cached_reqs.new_block_ids[i] + + cached_req = self._active_requests[req_id] + req_block_ids = self._req_blocks[req_id] + + assert new_block_ids is not None + block_ids = new_block_ids[0] + + req_block_ids.extend(block_ids) + filename = os.path.join(self._storage_path, f"{req_id}.safetensors") + + meta.add_request( + req_id=req_id, + filename=filename, + token_ids=cached_req.prompt_token_ids or [], + block_ids=req_block_ids, + 
block_size=self._block_size, + new_req=False, + ) + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Called exactly once when a request has finished, before its blocks are + freed. + + The connector may assumes responsibility for freeing the blocks + asynchronously by returning True. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + req_id = request.request_id + req_filename = self._request_filenames.pop(req_id, None) + _ = self._active_requests.pop(req_id, None) + _ = self._req_blocks.pop(req_id, None) + + return False, {"hidden_states_path": req_filename} + + @classmethod + def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: + """ + Get the required KV cache layout for this connector. + Args: + vllm_config (VllmConfig): the vllm config. + + Returns: + str: the required KV cache layout. e.g. HND, or NHD. + None if the connector does not require a specific layout. + """ + + if cls is KVConnectorBase_V1: + raise TypeError( + "get_required_kvcache_layout should not be called " + "on the abstract base class" + ) + # NHD means we have (num_tokens, num_heads) + # HND means we have (num_heads, num_tokens) + # For now, we only support NHD layout since this keeps the + # hidden states for each token together in memory. + # HND is primarily used when sharding heads across devices. 
+ return "NHD" diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py new file mode 100644 index 000000000000..ae9bdb5ed4e5 --- /dev/null +++ b/vllm/model_executor/models/extract_hidden_states.py @@ -0,0 +1,394 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Hidden States Extractor Model. + +This model extracts and caches hidden states from the target model +without performing actual token generation. It's used with the +extract_hidden_states speculative decoding method. +""" + +from collections.abc import Iterable +from typing import ClassVar + +import torch +import torch.nn as nn + +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.config.cache import CacheDType +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.attention.attention import set_default_quant_scales +from vllm.model_executor.layers.attention.kv_transfer_utils import ( + maybe_transfer_kv_layer, +) +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.models.utils import maybe_prefix +from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype +from vllm.v1.attention.backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadataBuilder, + AttentionType, + CommonAttentionMetadata, + is_quantized_kv_cache, +) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + KVCacheSpec, + MLAAttentionSpec, +) + +########## Custom Ops ######## + + +def unified_kv_cache_update( + to_cache: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + """ + Returns a dummy that is passed to unified_attention to signal a side effect and + the data dependency between them to ensure torch.compile preserves ordering. 
+ """ + forward_context = get_forward_context() + attn_layer = forward_context.no_compile_layers[layer_name] + kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] + + slot_mapping = forward_context.slot_mapping + assert isinstance(slot_mapping, dict), ( + f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " + ) + layer_slot_mapping = slot_mapping.get(layer_name) + if layer_slot_mapping is not None: + assert hasattr(attn_layer.impl, "do_kv_cache_update"), ( + f"{attn_layer.impl.__class__.__name__} does not support kv cache update" + ) + attn_layer.impl.do_kv_cache_update( + attn_layer, + to_cache, + kv_cache, + layer_slot_mapping, + ) + + return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype) + + +@maybe_transfer_kv_layer +def dummy_attention(layer_name, _placeholder): + # Note: layer_name arg required by @maybe_transfer_kv_layer + return _placeholder + + +def basic_cache( + to_cache: torch.Tensor, # shape: [num_blocks, block_size, num_heads, head_size] + kv_cache: torch.Tensor, # shape: [seq_len, num_heads, head_size] + slot_mapping: torch.Tensor, # shape: [seq_len] +): + num_blocks, block_size, num_heads, head_size = kv_cache.shape + token_kv_cache = kv_cache.view(num_blocks * block_size, num_heads, head_size) + token_kv_cache[slot_mapping] = to_cache + + +######### CacheOnlyAttentionBackend ######## + + +class CacheOnlyAttentionBackend(AttentionBackend): + """Attention backend that only caches KV without computing attention.""" + + accept_output_buffer: bool = False + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "bfloat16", + ] + forward_includes_kv_cache_update: bool = False + + @staticmethod + def get_name() -> str: + return "CACHE_ONLY_ATTN" + + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + return attn_type == AttentionType.DECODER + + @classmethod + def 
class CacheOnlyAttentionBackend(AttentionBackend):
    """Attention backend that only caches KV without computing attention."""

    accept_output_buffer: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    # NOTE(review): CacheOnlyAttentionLayer's assert also accepts "float16",
    # which is absent here — confirm which list is authoritative.
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
    ]
    # Cache writes happen via unified_kv_cache_update, not inside forward.
    forward_includes_kv_cache_update: bool = False

    @staticmethod
    def get_name() -> str:
        return "CACHE_ONLY_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        return True

    @staticmethod
    def get_impl_cls() -> type["CacheOnlyAttentionImpl"]:
        return CacheOnlyAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # We set `num_kv_heads = num_hidden_layers` and `head_size = hidden_size`
        # We also don't use a k/v (2) dim
        return (num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]:
        return CacheOnlyAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        # Empty list: any head size is acceptable (head_size == hidden_size here).
        return []


class CacheOnlyAttentionMetadata:
    """Minimal metadata: only the flat slot mapping used for cache writes."""

    def __init__(self, slot_mapping: torch.Tensor):
        self.slot_mapping = slot_mapping


class CacheOnlyAttentionMetadataBuilder(
    AttentionMetadataBuilder[CacheOnlyAttentionMetadata]
):
    """Builder that strips CommonAttentionMetadata down to a slot mapping."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CacheOnlyAttentionMetadata:
        """Build metadata; rejects cascade and non-causal configurations."""
        use_cascade = common_prefix_len > 0
        if use_cascade:
            raise NotImplementedError(
                "Cascade attention not supported by CacheOnlyAttention"
            )
        causal = common_attn_metadata.causal
        if not causal:
            raise NotImplementedError(
                "Non-causal attention not supported by CacheOnlyAttention"
            )

        return CacheOnlyAttentionMetadata(
            slot_mapping=common_attn_metadata.slot_mapping,
        )


class CacheOnlyAttentionImpl(AttentionImpl):
    """Attention implementation that only caches KV states."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        kv_cache_dtype: str,
        kv_cache_torch_dtype: torch.dtype,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_cache_torch_dtype = kv_cache_torch_dtype

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(f"Unsupported attention type: {attn_type}")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("Quantized KV cache not supported")

        self.num_queries_per_kv = 1

    def do_kv_cache_update(
        self,
        layer,
        to_cache,
        kv_cache,
        slot_mapping,
    ):
        """Write ``to_cache`` rows into the paged cache (dtype-checked)."""
        assert to_cache.dtype == self.kv_cache_torch_dtype, (
            f"Data to cache must be {self.kv_cache_torch_dtype}, got {to_cache.dtype}"
        )
        assert kv_cache.dtype == self.kv_cache_torch_dtype, (
            f"KV cache must be {self.kv_cache_torch_dtype}, got {kv_cache.dtype}"
        )

        basic_cache(to_cache, kv_cache, slot_mapping)

    def forward(self, *args, **kwargs):
        # Empty implementation of abstract method
        pass
class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
    """Attention layer that only caches key/value states without computing attention."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.layer_name = prefix

        vllm_config = get_current_vllm_config()

        # KV cache configuration
        cache_config = cache_config or vllm_config.cache_config
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            self.block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            self.block_size = 16

        # NOTE(review): "float16" is accepted here but is not listed in
        # CacheOnlyAttentionBackend.supported_kv_cache_dtypes — confirm intent.
        assert kv_cache_dtype in ["auto", "bfloat16", "float16"], (
            "CacheOnlyAttentionLayer doesn't currently support quantized kv cache but"
            f"kv cache dtype was set to {kv_cache_dtype}"
        )
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )

        # Initialize KV cache quantization attributes
        set_default_quant_scales(self, register_buffer=True)

        # Attention backend
        self.attn_backend = CacheOnlyAttentionBackend
        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            kv_cache_dtype,
            self.kv_cache_torch_dtype,
            attn_type,
        )

        assert not self.attn_backend.forward_includes_kv_cache_update, (
            "KV cache update should be independent of forward"
        )

        # Placeholder KV cache (replaced by bind_kv_cache)
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Register in compilation context
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def forward(self, to_cache: torch.Tensor) -> torch.Tensor:
        """Cache hidden states as KV pairs without computing attention.

        Args:
            to_cache: The tensor to insert into the kv cache.
                shape [num_tokens, num_heads, head_size]

        Returns:
            Dummy output tensor (not used)
        """
        # Note: we set num_heads to num_hidden_layers and
        # head_size to hidden_size for hidden states storage
        output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype)

        # Note: dummy_out is used to force torch.compile to preserve ordering between
        # cache update and attention op (which triggers kv_connector transfer)
        dummy_out = unified_kv_cache_update(to_cache, self.layer_name)

        # Triggers kv_connector transfer via decorator
        _ = dummy_attention(self.layer_name, dummy_out)

        return output

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Note: we use MLAAttentionSpec here because it will
        # produce page sizes of (block_size * num_kv_heads * head_size * dtype_size)
        # whereas FullAttentionSpec will add an additional factor of 2
        return MLAAttentionSpec(
            block_size=self.block_size,
            num_kv_heads=self.num_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )


class ExtractHiddenStatesModel(nn.Module):
    """Dummy draft model that stores target-model hidden states in its KV cache."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.vllm_config = vllm_config
        self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.hidden_size = vllm_config.model_config.get_hidden_size()
        self.target_num_hidden_layers = (
            vllm_config.model_config.get_total_num_hidden_layers()
        )
        # One "head" per captured target layer.
        self.num_hidden_states = len(
            getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", [])
        )

        cache_config = vllm_config.cache_config

        # Create a single cache-only attention layer
        # Note: We set num_heads <- self.num_hidden_states
        # and head_size <- hidden_size so that we can insert
        # the hidden states directly into the cache without
        # reshaping
        self.cache_only_layers = nn.ModuleDict(
            {
                str(self.target_num_hidden_layers): CacheOnlyAttentionLayer(
                    num_heads=self.num_hidden_states,
                    head_size=self.hidden_size,
                    cache_config=cache_config,
                    prefix=maybe_prefix(
                        prefix, f"cache_only_layers.{self.target_num_hidden_layers}"
                    ),
                )
            }
        )

    def forward(self, hidden_states: torch.Tensor) -> None:
        """Process and cache hidden states.

        Args:
            hidden_states: Hidden states from target model
                shape: [num_tokens, num_hidden_states, hidden_size]

        Returns:
            None — only the KV-cache side effects matter.
        """

        # Call dummy attention layer to cache hidden states
        # Output is ignored - we only care about the KV cache side effects
        _ = self.cache_only_layers[str(self.target_num_hidden_layers)](hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """No weights to load for this dummy model."""
        return set()
class ExtractHiddenStatesConfig(PretrainedConfig):
    """HF-style config for ExtractHiddenStatesModel (method
    ``extract_hidden_states``); wraps the target model's config fields and
    pins ``architectures`` to ``["ExtractHiddenStatesModel"]``."""

    model_type = "extract_hidden_states"

    def __init__(
        self,
        model: PretrainedConfig | dict | None = None,
        method: str | None = "extract_hidden_states",
        **kwargs,
    ):
        assert method == "extract_hidden_states"

        # Normalize the wrapped target-model config to a plain dict.
        if isinstance(model, dict):
            model_dict = model
        elif isinstance(model, PretrainedConfig):
            model_dict = model.to_dict()
        else:
            model_dict = {}

        # Combine: model_dict first, then kwargs override
        combined = {**model_dict, **kwargs}
        # Remove architectures from the base, we'll set it explicitly
        combined = {k: v for k, v in combined.items() if k != "architectures"}

        combined["architectures"] = ["ExtractHiddenStatesModel"]

        super().__init__(**combined)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike,
        **kwargs,
    ) -> "ExtractHiddenStatesConfig":
        """Load via the standard HF config-dict mechanism."""
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        return cls.from_dict(config_dict, **kwargs)

    def to_json_string(self, use_diff: bool = True) -> str:
        # we override use_diff to False as initializing
        # ExtractHiddenStatesConfig with default arguments is not supported
        del use_diff
        return super().to_json_string(use_diff=False)
list[T | None]) -> T | None: + non_none = [item for item in items if item is not None] + if len(non_none) == 0: + return None + + combined = non_none[0] + for item in non_none[1:]: + combined = f(combined, item) + return combined + + @dataclass class KVConnectorOutput: # [req_ids] @@ -146,6 +161,43 @@ def is_empty(self): and not self.invalid_block_ids ) + @classmethod + def merge(cls, *outputs: "KVConnectorOutput"): + assert len(outputs) > 0, "Cannot merge empty outputs" + finished_sending = _combine_non_none( + set.union, [output.finished_sending for output in outputs] + ) + finished_recving = _combine_non_none( + set.union, [output.finished_recving for output in outputs] + ) + kv_connector_stats = _combine_non_none( + lambda x, y: x.aggregate(y), + [output.kv_connector_stats for output in outputs], + ) + kv_cache_events = _combine_non_none( + lambda x, y: x.merge(y), + [output.kv_cache_events for output in outputs], + ) + invalid_block_ids = _combine_non_none( + set.union, [output.invalid_block_ids for output in outputs] + ) + assert invalid_block_ids is not None + + assert all( + output.expected_finished_count == outputs[0].expected_finished_count + for output in outputs + ) + expected_finished_count = outputs[0].expected_finished_count + + return cls( + finished_sending=finished_sending, + finished_recving=finished_recving, + kv_connector_stats=kv_connector_stats, + kv_cache_events=kv_cache_events, + invalid_block_ids=invalid_block_ids, + expected_finished_count=expected_finished_count, + ) + @dataclass class ECConnectorOutput: diff --git a/vllm/v1/spec_decode/extract_hidden_states.py b/vllm/v1/spec_decode/extract_hidden_states.py new file mode 100644 index 000000000000..38a54f01696c --- /dev/null +++ b/vllm/v1/spec_decode/extract_hidden_states.py @@ -0,0 +1,395 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from contextlib import nullcontext +from typing 
# Slot index used to pad slot-mapping buffers past the valid tokens.
PADDING_SLOT_ID = -1


class ExtractHiddenStatesProposer:
    """Drafter for the ``extract_hidden_states`` method: never speculates,
    only runs ExtractHiddenStatesModel so hidden states land in the KV cache."""

    def __init__(self, vllm_config: VllmConfig, device):
        assert vllm_config.speculative_config is not None

        # Exactly one "draft" token: the proposer echoes the sampled token.
        assert vllm_config.speculative_config.num_speculative_tokens == 1
        if vllm_config.speculative_config.disable_padded_drafter_batch:
            raise ValueError(
                "disable_padded_drafter_batch is not supported with "
                "extract_hidden_states method"
            )
        self.vllm_config = vllm_config
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank

        # Model and attention layer tracking (initialized in load_model)
        self.model: nn.Module | None = None
        self.attn_layer_names: list[str] = []
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None

        # Maximum number of tokens for buffers
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
        )

        self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
        if not layer_ids:
            raise ValueError(
                "eagle_aux_hidden_state_layer_ids must be set in the draft "
                "model config for extract_hidden_states method"
            )
        self.num_hidden_states = len(layer_ids)
        self.hidden_size = vllm_config.model_config.get_hidden_size()
        # Persistent staging buffer for stacked target hidden states.
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.num_hidden_states, self.hidden_size),
            dtype=self.dtype,
            device=device,
        )
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        self._slot_mapping_buffer = torch.zeros(
            self.max_num_tokens, dtype=torch.int64, device=device
        )

    def propose(
        self,
        sampled_token_ids: torch.Tensor,
        target_hidden_states: list[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        scheduler_output: SchedulerOutput,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> tuple[torch.Tensor, KVConnectorOutput | None]:
        """Propose draft tokens by calling the ExtractHiddenStatesModel model.

        The ExtractHiddenStatesModel caches the hidden states in the KV cache
        without performing actual attention computation. This allows us to
        extract and store hidden states for later use (e.g., KV transfer).

        This proposer doesn't actually perform speculation - it returns the
        sampled tokens as "draft" tokens, ensuring they always verify (match).
        The main purpose is to cache hidden states, not to speculate.

        Args:
            sampled_token_ids: Sampled token IDs from the target model
            target_hidden_states: List of hidden state tensors from target model
                (one per aux hidden state layer)
            common_attn_metadata: Attention metadata
            scheduler_output: Scheduler output for KV connector
            slot_mappings: Slot mappings for KV cache (unused, provided for
                interface compatibility)

        Returns:
            Tuple of:
            - Draft tokens matching sampled tokens, shape [batch_size, 1]
            - KV connector output (if KV transfer is active), else None
        """
        assert self.model is not None and isinstance(target_hidden_states, list)

        # target_hidden_states is a list of tensors (one per layer)
        # Each tensor has shape [num_tokens, hidden_size]
        # Stack to shape: [num_tokens, num_hidden_states, hidden_size]
        stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
        num_tokens = stacked_hidden_states.shape[0]

        # Copy hidden states to buffer
        self.hidden_states[:num_tokens] = stacked_hidden_states

        assert self.attn_metadata_builder is not None
        attn_metadata = self.attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )

        # We assume all cache-only layers belong to the same KV cache group,
        # thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(num_tokens)
        )
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        with (
            set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    num_input_tokens, common_attn_metadata.slot_mapping
                ),
            ),
            (
                KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
                if has_kv_transfer_group()
                else nullcontext()
            ) as kv_connector_output,
        ):
            self.model(
                hidden_states=self.hidden_states[:num_input_tokens],
            )

        # Return the sampled tokens as "draft" tokens
        # Shape: [batch_size, 1] to match num_speculative_tokens=1
        return sampled_token_ids.unsqueeze(-1), kv_connector_output

    def _get_slot_mapping(
        self,
        num_tokens: int,
        slot_mapping: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Return slot_mapping dict for cache-only attention layers.

        If slot_mapping is provided, copies it into the buffer first.
        """
        if slot_mapping is not None:
            num_actual = slot_mapping.shape[0]
            self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
            # Pad the tail so padded tokens write to the padding slot.
            if num_tokens > num_actual:
                self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

        view = self._slot_mapping_buffer[:num_tokens]
        return {name: view for name in self.attn_layer_names}

    def _determine_batch_execution_and_padding(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
        """Pick cudagraph mode and padded token count, coordinating across DP.

        Returns (cudagraph_mode, padded num_tokens, per-rank token counts or
        None when not running data-parallel).
        """
        cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens,
            valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
        )
        num_tokens_padded = batch_desc.num_tokens

        # Extra coordination when running data-parallel since we need to
        # coordinate across ranks
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_tokens_across_dp = False, None
        if self.vllm_config.parallel_config.data_parallel_size > 1:
            should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
                coordinate_batch_across_dp(
                    num_tokens_unpadded=num_tokens,
                    parallel_config=self.vllm_config.parallel_config,
                    allow_microbatching=False,
                    num_tokens_padded=num_tokens_padded,
                    cudagraph_mode=cudagraph_mode.value,
                )
            )
            assert not should_ubatch, (
                "DBO ubatching not implemented for extract_hidden_states"
            )

            # Extract DP-synced values
            if num_tokens_across_dp is not None:
                dp_rank = self.dp_rank
                num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
                # Re-dispatch with DP padding so we have the correct
                # batch_descriptor
                cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                    num_tokens_padded,
                    valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
                )
                # Assert to make sure the agreed upon token count is correct
                # otherwise num_tokens_across_dp will no-longer be valid
                assert batch_desc.num_tokens == num_tokens_padded
                num_tokens_across_dp[dp_rank] = num_tokens_padded

        return cudagraph_mode, num_tokens_padded, num_tokens_across_dp

    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Initialize cudagraph dispatcher keys.

        Only supports PIECEWISE cudagraphs (via mixed_mode).
        Should be called after adjust_cudagraph_sizes_for_spec_decode.
        """
        assert self.vllm_config.speculative_config is not None
        if (
            not self.vllm_config.speculative_config.enforce_eager
            and cudagraph_mode.mixed_mode()
            in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
        ):
            proposer_cudagraph_mode = CUDAGraphMode.PIECEWISE
        else:
            proposer_cudagraph_mode = CUDAGraphMode.NONE

        self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run the model on dummy data (warmup / cudagraph capture)."""
        assert self.model is not None, "Model must be initialized before dummy_run"
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_tokens, use_cudagraphs=use_cudagraphs
            )
        )

        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        # Use our own slot mapping buffer during cudagraph capture.
        if (
            self.attn_layer_names
            and slot_mappings is not None
            and self.attn_layer_names[0] in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}

        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=slot_mapping_dict,
        ):
            self.model(
                hidden_states=self.hidden_states[:num_input_tokens],
            )

    def _build_attn_metadata_builder(
        self, draft_attn_layers: dict[str, AttentionLayerBase]
    ) -> AttentionMetadataBuilder:
        """Build the attention metadata builder from draft attention layers."""
        if not draft_attn_layers:
            raise ValueError("No attention layers found for ExtractHiddenStatesModel")
        # All cache-only layers share one backend; use the first layer's spec.
        layer = next(iter(draft_attn_layers.values()))
        attn_backend = layer.get_attn_backend()
        return attn_backend.get_builder_cls()(
            layer.get_kv_cache_spec(self.vllm_config),
            self.attn_layer_names,
            self.vllm_config,
            self.device,
        )
+ """ + num_reqs = gpu_input_batch.num_reqs + device = sampled_token_ids.device + + # Compute backup tokens for discarded / invalid requests + backup_tokens_gpu = torch.tensor( + [ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item() + ) + for i in range(num_reqs) + ], + dtype=torch.int32, + device=device, + ) + + assert discard_request_mask.dtype == torch.bool + + # With num_speculative_tokens == 1, there is exactly one token + sampled = sampled_token_ids[:, 0] + is_valid = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size) + valid_sampled_tokens_count = is_valid.to(torch.int32) + + use_sampled = is_valid & ~discard_request_mask[:num_reqs] + next_token_ids = torch.where( + use_sampled, sampled.to(torch.int32), backup_tokens_gpu + ) + + return next_token_ids, valid_sampled_tokens_count + + def load_model(self, target_model: nn.Module) -> None: + """Load the ExtractHiddenStatesModel model. + + This method instantiates the ExtractHiddenStatesModel model which is used + to cache hidden states during speculative decoding. The model uses + cache-only attention (no computation, just caching KV states). 
+ + Args: + target_model: The target model (passed for compatibility with + EagleProposer interface, but not used here) + """ + # Get the target model's attention layers before loading draft model + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() # type: ignore[type-abstract] + ) + + assert self.vllm_config.speculative_config is not None + draft_model_config = self.vllm_config.speculative_config.draft_model_config + from vllm.compilation.backends import set_model_tag + + with set_model_tag("extract_hidden_states"): + self.model = get_model( + vllm_config=self.vllm_config, model_config=draft_model_config + ) + + # Identify draft model's attention layers (difference from target) + all_attn_layers = get_layers_from_vllm_config( + self.vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + ) + draft_attn_layers = { + name: layer + for name, layer in all_attn_layers.items() + if name not in target_attn_layer_names + } + self.attn_layer_names = list(draft_attn_layers.keys()) + assert len(draft_attn_layers) == 1, ( + "ExtractHiddenStatesModel should have exactly one " + f"attention layer, found {len(draft_attn_layers)}" + ) + self.attn_metadata_builder = self._build_attn_metadata_builder( + draft_attn_layers + ) + + def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: + """Validate all drafting layers belong to the same KV cache group. + + With exactly one attention layer (asserted in load_model), this is + trivially satisfied. 
+ """ + assert len(self.attn_layer_names) == 1 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 36abee66ea86..c99d8f164546 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -159,6 +159,7 @@ from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer @@ -495,6 +496,7 @@ def __init__( | EagleProposer | DraftModelProposer | MedusaProposer + | ExtractHiddenStatesProposer ) if self.speculative_config.method == "ngram": from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -518,6 +520,11 @@ def __init__( self.drafter = MedusaProposer( vllm_config=self.vllm_config, device=self.device ) + elif self.speculative_config.method == "extract_hidden_states": + self.drafter = ExtractHiddenStatesProposer( + vllm_config=self.vllm_config, device=self.device + ) + self.use_aux_hidden_state_outputs = True else: raise ValueError( "Unknown speculative decoding method: " @@ -3693,10 +3700,9 @@ def execute_model( def sample_tokens( self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: - kv_connector_output = self.kv_connector_output - self.kv_connector_output = None - if self.execute_model_state is None: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None # receive sampled token ids from the last PP rank. 
if self.use_async_scheduling and get_pp_group().world_size > 1: self._pp_receive_prev_sampled_token_ids_to_input_batch() @@ -3778,12 +3784,17 @@ def propose_draft_token_ids(sampled_token_ids): <= self.effective_drafter_max_model_len ) use_gpu_toks = ( - spec_config.use_eagle() or spec_config.uses_draft_model() + spec_config.use_eagle() + or spec_config.uses_draft_model() + or spec_config.uses_extract_hidden_states() ) and not spec_config.disable_padded_drafter_batch if use_gpu_toks: # EAGLE/DraftModel speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. - assert isinstance(self.drafter, EagleProposer | DraftModelProposer) + assert isinstance( + self.drafter, + EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, + ) sampled_token_ids = sampler_output.sampled_token_ids if input_fits_in_drafter: propose_draft_token_ids(sampled_token_ids) @@ -3842,6 +3853,10 @@ def propose_draft_token_ids(sampled_token_ids): with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() + # self.kv_connector_output may be modified during drafting + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): if self.model_config.enable_return_routed_experts: capturer = RoutedExpertsCapturer.get_instance() @@ -4068,6 +4083,48 @@ def propose_draft_token_ids( sampling_metadata=sampling_metadata, slot_mappings=slot_mappings, ) + elif spec_config.uses_extract_hidden_states(): + assert isinstance(self.drafter, ExtractHiddenStatesProposer) + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor for " + "extract_hidden_states method." 
+ ) + if not self.use_aux_hidden_state_outputs or aux_hidden_states is None: + raise ValueError( + "aux_hidden_states are required when using `extract_hidden_states`" + ) + target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states] + + draft_token_ids, drafter_kv_connector_output = self.drafter.propose( + sampled_token_ids=sampled_token_ids, + target_hidden_states=target_hidden_states, + common_attn_metadata=common_attn_metadata, + scheduler_output=scheduler_output, + slot_mappings=slot_mappings, + ) + # Combine KVConnectorOutputs or select the non-empty one + if self.kv_connector_output and drafter_kv_connector_output: + self.kv_connector_output = KVConnectorOutput.merge( + self.kv_connector_output, drafter_kv_connector_output + ) + else: + self.kv_connector_output = ( + self.kv_connector_output or drafter_kv_connector_output + ) + + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + elif spec_config.use_eagle() or spec_config.uses_draft_model(): assert isinstance(self.drafter, EagleProposer | DraftModelProposer) @@ -4946,8 +5003,12 @@ def _dummy_run( if self.speculative_config and ( self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model() + or self.speculative_config.uses_extract_hidden_states() ): - assert isinstance(self.drafter, EagleProposer | DraftModelProposer) + assert isinstance( + self.drafter, + EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, + ) assert self.speculative_config is not None # Eagle currently only supports PIECEWISE cudagraphs. 
# Therefore only use cudagraphs if the main model uses PIECEWISE @@ -5656,9 +5717,12 @@ def _check_and_update_cudagraph_mode( cudagraph_mode, self.uniform_decode_query_len ) - # Initialize eagle's cudagraph dispatcher if using eagle spec decode. - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + # Initialize drafter's cudagraph dispatcher if using spec decode. + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_extract_hidden_states() + ): + assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer) self.drafter.initialize_cudagraph_keys(cudagraph_mode) def calculate_reorder_batch_threshold(self) -> None: @@ -6025,8 +6089,12 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if self.speculative_config and ( self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model() + or self.speculative_config.uses_extract_hidden_states() ): - assert isinstance(self.drafter, EagleProposer | DraftModelProposer) + assert isinstance( + self.drafter, + EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, + ) # validate all draft model layers belong to the same kv cache # group self.drafter.validate_same_kv_cache_group(kv_cache_config)