diff --git a/.buildkite/test_areas/misc.yaml b/.buildkite/test_areas/misc.yaml index 2a78201a9e47..7605f08397fa 100644 --- a/.buildkite/test_areas/misc.yaml +++ b/.buildkite/test_areas/misc.yaml @@ -127,6 +127,19 @@ steps: - pytest -v -s -m 'cpu_test' v1/kv_connector/unit - pytest -v -s -m 'cpu_test' v1/metrics +- label: Extract Hidden States Integration + key: extract-hidden-states-integration + timeout_in_minutes: 20 + device: h200_18gb + source_file_dependencies: + - vllm/v1/spec_decode/extract_hidden_states.py + - vllm/model_executor/models/extract_hidden_states.py + - vllm/transformers_utils/configs/extract_hidden_states.py + - tests/v1/kv_connector/extract_hidden_states_integration + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s v1/kv_connector/extract_hidden_states_integration + - label: Regression key: regression timeout_in_minutes: 20 diff --git a/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py b/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py index 6a8c64152fec..1426beb9d060 100644 --- a/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py +++ b/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py @@ -83,7 +83,7 @@ def register_predictable_model(): def test_extract_hidden_states_with_predictable_dummy_model( - predictable_llama_config_path, tmp_path + predictable_llama_config_path, tmp_path, monkeypatch ): """Comprehensive test using a predictable dummy model with synthetic weights. @@ -94,6 +94,12 @@ def test_extract_hidden_states_with_predictable_dummy_model( 3. Layer ordering is preserved correctly (non-sequential layer IDs) 4. Multiple prompts of different lengths produce consistent layer values """ + # Force fork so the engine worker inherits the autouse fixture's + # ModelRegistry.register_model("PredictableLlamaForCausalLM", ...). + # Spawn (the CI default) starts a fresh Python process that wouldn't + # see the registration. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") + # Test with non-sequential layer ordering to verify correct association layer_ids = [5, 2, 10] num_layers = len(layer_ids) @@ -153,3 +159,55 @@ def test_extract_hidden_states_with_predictable_dummy_model( f"but got mean={layer_hidden.mean():.3f}, " f"min={layer_hidden.min():.3f}, max={layer_hidden.max():.3f}" ) + + +def test_extract_hidden_states_qwen35_hybrid_smoke(tmp_path): + """Smoke test for Qwen3.5 hybrid (mamba + full-attention) models. + Uses load_format="dummy" to just check shape/plumbing. + """ + layer_ids = [5, 11, 17] + hidden_size = 1024 # Qwen/Qwen3.5-0.8B hidden_size + + llm = LLM( + model="Qwen/Qwen3.5-0.8B", + speculative_config={ + "method": "extract_hidden_states", + "num_speculative_tokens": 1, + "draft_model_config": { + "hf_config": {"eagle_aux_hidden_state_layer_ids": layer_ids} + }, + }, + kv_transfer_config={ + "kv_connector": "ExampleHiddenStatesConnector", + "kv_role": "kv_producer", + "kv_connector_extra_config": {"shared_storage_path": str(tmp_path)}, + }, + max_model_len=256, + enforce_eager=True, + gpu_memory_utilization=0.4, + load_format="dummy", + ) + + prompts = ["Hello world", "Test prompt with several tokens"] + sampling_params = SamplingParams(max_tokens=1, temperature=0.0) + outputs = llm.generate(prompts, sampling_params) + del llm + gc.collect() + + assert len(outputs) == len(prompts) + for output in outputs: + assert output.kv_transfer_params is not None + hidden_states_path = output.kv_transfer_params.get("hidden_states_path") + assert hidden_states_path is not None + assert os.path.exists(hidden_states_path) + + with safe_open(hidden_states_path, "pt") as f: + token_ids = f.get_tensor("token_ids") + hidden_states = f.get_tensor("hidden_states") + + assert torch.equal(token_ids, torch.tensor(output.prompt_token_ids)) + assert hidden_states.shape == ( + len(output.prompt_token_ids), + len(layer_ids), + hidden_size, + ) diff --git a/tests/v1/kv_connector/unit/test_decode_bench_connector.py b/tests/v1/kv_connector/unit/test_decode_bench_connector.py index 3af58d63c9a1..5f7c5eeefad0 100644 --- a/tests/v1/kv_connector/unit/test_decode_bench_connector.py +++ b/tests/v1/kv_connector/unit/test_decode_bench_connector.py @@ -11,7 +11,6 @@ import torch from vllm import SamplingParams -from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole # ruff: noqa: E501 @@ -44,11 +43,9 @@ def __init__(self, block_size: int, num_gpu_blocks: int): # Create vllm config with DecodeBenchConnector vllm_config = create_vllm_config( - block_size=block_size, max_num_batched_tokens=1000 - ) - vllm_config.kv_transfer_config = KVTransferConfig( + block_size=block_size, + max_num_batched_tokens=1000, kv_connector="DecodeBenchConnector", - kv_role="kv_both", ) self.vllm_config = vllm_config diff --git a/tests/v1/kv_connector/unit/test_kv_connector_lifecycle.py b/tests/v1/kv_connector/unit/test_kv_connector_lifecycle.py index a9a38a17b949..41456f772e57 100644 --- a/tests/v1/kv_connector/unit/test_kv_connector_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_kv_connector_lifecycle.py @@ -33,10 +33,11 @@ def _make_empty_scheduler_output(): def test_kv_connector_mixin_clears_metadata(): - vllm_config = create_vllm_config() - vllm_config.kv_transfer_config.kv_connector = "TestExampleConnector" - vllm_config.kv_transfer_config.kv_role = "kv_both" - vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit" + vllm_config = create_vllm_config( + kv_connector="TestExampleConnector", + kv_role="kv_both", + kv_connector_extra_config={"name": "unit"}, + ) # Initialize the global connector instance kv_cache_config = KVCacheConfig( diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index f0aa4e260f60..2747ae3496b4 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -130,17 +130,14 @@ def request_finished_all_groups(self, request, block_ids): @pytest.fixture def mc() -> MultiConnector: """MultiConnector using two mocked connectors""" - vllm_config = create_vllm_config() - mock_connector_config = { "kv_connector": "MockConnector", "kv_role": "kv_both", "kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector", } - vllm_config.kv_transfer_config = KVTransferConfig( + vllm_config = create_vllm_config( kv_connector="MultiConnector", - kv_role="kv_both", kv_connector_extra_config={ "connectors": [mock_connector_config, mock_connector_config], }, @@ -403,39 +400,35 @@ def test_multi_connector_handle_preemptions_integration(): try: # Configure MultiConnector with two TestExampleConnectors - kv_transfer_config = KVTransferConfig( - kv_connector="MultiConnector", - kv_role="kv_both", - kv_connector_extra_config={ - "connectors": [ - { - "kv_connector": "TestExampleConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_path / "s1"), - "name": "preempt1", - }, - "kv_connector_module_path": "tests.v1.kv_connector.unit.utils", + connectors_extra_config = { + "connectors": [ + { + "kv_connector": "TestExampleConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_path / "s1"), + "name": "preempt1", }, - { - "kv_connector": "TestExampleConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_path / "s2"), - "name": "preempt2", - }, - "kv_connector_module_path": "tests.v1.kv_connector.unit.utils", + "kv_connector_module_path": "tests.v1.kv_connector.unit.utils", + }, + { + "kv_connector": "TestExampleConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_path / "s2"), + "name": "preempt2", }, - ] - }, - ) + "kv_connector_module_path": "tests.v1.kv_connector.unit.utils", + }, + ] + } vllm_config = create_vllm_config( block_size=16, max_num_batched_tokens=100, - kv_connector_extra_config=kv_transfer_config.kv_connector_extra_config, + kv_connector="MultiConnector", + kv_connector_extra_config=connectors_extra_config, ) - vllm_config.kv_transfer_config = kv_transfer_config # Create scheduler - this initializes the MultiConnector with SCHEDULER role scheduler = create_scheduler(vllm_config, num_blocks=10) @@ -971,7 +964,6 @@ def assert_update_connector_output_called(mc: MultiConnector): def _make_multi_connector(connector_names: list[str]) -> MultiConnector: """Build a MultiConnector wrapping the given registered connectors.""" - vllm_config = create_vllm_config() connectors = [ { "kv_connector": name, @@ -980,9 +972,8 @@ def _make_multi_connector(connector_names: list[str]) -> MultiConnector: } for name in connector_names ] - vllm_config.kv_transfer_config = KVTransferConfig( + vllm_config = create_vllm_config( kv_connector="MultiConnector", - kv_role="kv_both", kv_connector_extra_config={"connectors": connectors}, ) kv_cache_config = KVCacheConfig( diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 5db2f7e7a919..1b892849d909 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -102,6 +102,7 @@ def create_vllm_config( attention_backend: str | None = None, kv_load_failure_policy: Literal["recompute", "fail"] = "fail", kv_connector: str = "NixlConnector", + kv_connector_module_path: str | None = None, kv_role: str = "kv_both", disable_hybrid_kv_cache_manager: bool | None = None, ) -> VllmConfig: @@ -130,6 +131,7 @@ def create_vllm_config( ) kv_transfer_config = KVTransferConfig( kv_connector=kv_connector, + kv_connector_module_path=kv_connector_module_path, kv_role=kv_role, enable_permute_local_kv=enable_permute_local_kv, kv_connector_extra_config=kv_connector_extra_config or {}, diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index d220aa65035d..3d08bc088799 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1316,6 +1316,10 @@ def has_blocked_weights(): "the `reasoning_start_str` and `reasoning_end_str`." ) + # Resolve kv_offloading-derived connector name into kv_transfer_config + # before the HMA check below, which inspects the connector class. + self._post_init_kv_transfer_config() + # Hybrid KV cache manager (HMA) runtime rules: # - Explicit enable (--no-disable-kv-cache-manager): error if runtime # disables it @@ -1353,18 +1357,42 @@ def has_blocked_weights(): if self.scheduler_config.disable_hybrid_kv_cache_manager is None: # Default to disable HMA, but only if the user didn't express a preference. if self.kv_transfer_config is not None: - # NOTE(Kuntai): turn HMA off for connector unless specifically enabled. - need_disable_hybrid_kv_cache_manager = True - logger.warning( - "Turning off hybrid kv cache manager because " - "`--kv-transfer-config` is set. This will reduce the " - "performance of vLLM on LLMs with sliding window attention " - "or Mamba attention. If you are a developer of kv connector" - ", please consider supporting hybrid kv cache manager for " - "your connector by making sure your connector is a subclass" - " of `SupportsHMA` defined in kv_connector/v1/base.py and" - " use --no-disable-hybrid-kv-cache-manager to start vLLM." + from vllm.config.kv_transfer import KVTransferConfig + from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory, + ) + from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + supports_hma, + ) + + connector_cls = KVConnectorFactory.get_connector_class( + self.kv_transfer_config ) + all_support_hma = supports_hma(connector_cls) + # MultiConnector subclasses SupportsHMA; only effectively + # supports HMA when every sub-connector does. + if all_support_hma and connector_cls.__name__ == "MultiConnector": + sub_ktcs = self.kv_transfer_config.kv_connector_extra_config.get( + "connectors", [] + ) + all_support_hma = all( + supports_hma( + KVConnectorFactory.get_connector_class( + KVTransferConfig(**sub) + ) + ) + for sub in sub_ktcs + ) + if not all_support_hma: + need_disable_hybrid_kv_cache_manager = True + logger.warning( + "Turning off hybrid kv cache manager because " + "connector %s does not subclass `SupportsHMA`. " + "This will reduce performance on models with " + "sliding window or Mamba attention. See " + "kv_connector/v1/base.py for details.", + connector_cls.__name__, + ) self.scheduler_config.disable_hybrid_kv_cache_manager = ( need_disable_hybrid_kv_cache_manager ) @@ -1406,10 +1434,7 @@ def has_blocked_weights(): if "-quant_fp8" not in custom_ops: custom_ops.append("+quant_fp8") - # Handle the KV connector configs - self._post_init_kv_transfer_config() self._verify_kv_transfer_compat() - # Log the custom passes that are enabled self.compilation_config.pass_config.log_enabled_passes() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py index bf56db32e4f8..98b03c4ebb12 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py @@ -12,6 +12,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + SupportsHMA, ) from vllm.logger import init_logger from vllm.v1.attention.backend import AttentionMetadata @@ -30,13 +31,9 @@ def extract_from_kv_cache( slot_mapping: torch.Tensor, num_tokens: int, ) -> torch.Tensor: - """Extract data from KV cache - Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size) - """ - - padded_kv = kv_cache.flatten(0, 1)[slot_mapping] - # shape: [len(slot_mapping), num_heads, head_size] - return padded_kv[:num_tokens] # shape: [num_tokens, num_heads, head_size] + """Extract data from KV cache.""" + block_size = kv_cache.shape[1] + return kv_cache[slot_mapping // block_size, slot_mapping % block_size][:num_tokens] @dataclass @@ -47,8 +44,6 @@ class ReqMeta: filename: str # Request tokens token_ids: torch.Tensor - # Slot mappings, should have the same length as token_ids - slot_mapping: torch.Tensor # Whether this request is a new request or partially computed already new_req: bool @@ -57,24 +52,12 @@ def make_meta( req_id: str, filename: str, token_ids: list[int], - block_ids: list[int], - block_size: int, new_req: bool, ) -> "ReqMeta": - token_ids_tensor = torch.tensor(token_ids) - block_ids_tensor = torch.tensor(block_ids) - num_blocks = block_ids_tensor.shape[0] - block_offsets = torch.arange(0, block_size) - slot_mapping = ( - block_offsets.reshape((1, block_size)) - + block_ids_tensor.reshape((num_blocks, 1)) * block_size - ) - slot_mapping = slot_mapping.flatten() return ReqMeta( req_id=req_id, filename=filename, - token_ids=token_ids_tensor, - slot_mapping=slot_mapping, + token_ids=torch.tensor(token_ids), new_req=new_req, ) @@ -88,18 +71,12 @@ def add_request( req_id: str, filename: str, token_ids: list[int], - block_ids: list[int], - block_size: int, new_req: bool = True, ) -> None: - self.requests.append( - ReqMeta.make_meta( - req_id, filename, token_ids, block_ids, block_size, new_req - ) - ) + self.requests.append(ReqMeta.make_meta(req_id, filename, token_ids, new_req)) -class ExampleHiddenStatesConnector(KVConnectorBase_V1): +class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA): """ Simple debug implementation of a HiddenStatesConnector. @@ -206,9 +183,16 @@ def save_kv_layer( assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata) os.makedirs(self._storage_path, exist_ok=True) + + slot_mapping = attn_metadata.slot_mapping + offset = 0 for request in connector_metadata.requests: + num_tokens = request.token_ids.shape[0] + req_slot_mapping = slot_mapping[offset : offset + num_tokens] + offset += num_tokens + hidden_states = extract_from_kv_cache( - kv_layer, request.slot_mapping, request.token_ids.shape[0] + kv_layer, req_slot_mapping, num_tokens ) tensors = { "hidden_states": hidden_states.detach().cpu(), @@ -269,8 +253,6 @@ def build_connector_meta( new_req.req_id, filename=filename, token_ids=token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size, ) self._request_filenames[new_req.req_id] = filename self._active_requests[new_req.req_id] = new_req @@ -298,8 +280,6 @@ def build_connector_meta( req_id=req_id, filename=filename, token_ids=cached_req.prompt_token_ids or [], - block_ids=req_block_ids, - block_size=self._block_size, new_req=False, ) @@ -331,6 +311,13 @@ def request_finished( return False, {"hidden_states_path": req_filename} + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + return self.request_finished(request, block_ids[0]) + @classmethod def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: """ diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py index e2d39ead66d9..8df4823b6973 100644 --- a/vllm/model_executor/models/extract_hidden_states.py +++ b/vllm/model_executor/models/extract_hidden_states.py @@ -34,8 +34,8 @@ ) from vllm.v1.kv_cache_interface import ( AttentionSpec, + HiddenStateCacheSpec, KVCacheSpec, - MLAAttentionSpec, ) ########## Custom Ops ######## @@ -79,13 +79,12 @@ def dummy_attention(layer_name, _placeholder): def basic_cache( - to_cache: torch.Tensor, # shape: [num_blocks, block_size, num_heads, head_size] - kv_cache: torch.Tensor, # shape: [seq_len, num_heads, head_size] + to_cache: torch.Tensor, # shape: [seq_len, num_heads, head_size] + kv_cache: torch.Tensor, # shape: [num_blocks, block_size, num_heads, head_size] slot_mapping: torch.Tensor, # shape: [seq_len] ): - num_blocks, block_size, num_heads, head_size = kv_cache.shape - token_kv_cache = kv_cache.view(num_blocks * block_size, num_heads, head_size) - token_kv_cache[slot_mapping] = to_cache + block_size = kv_cache.shape[1] + kv_cache[slot_mapping // block_size, slot_mapping % block_size] = to_cache ######### CacheOnlyAttentionBackend ######## @@ -322,11 +321,9 @@ def get_attn_backend(self) -> type[AttentionBackend]: return self.attn_backend def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - # Note: we use MLAAttentionSpec here to because it will - # produce page sizes of (block_size * num_kv_heads * head_size * dtype_size) - # whereas FullAttentionSpec will add an additional factor of 2 - return MLAAttentionSpec( - block_size=self.block_size, + # Re-read block_size: hybrid models may bump it after __init__. + return HiddenStateCacheSpec( + block_size=vllm_config.cache_config.block_size, num_kv_heads=self.num_heads, head_size=self.head_size, dtype=self.kv_cache_torch_dtype, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b57e10b67faa..57182734c761 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -18,9 +18,11 @@ from vllm.utils.hashing import sha256_cbor, xxhash_cbor from vllm.utils.math_utils import cdiv, round_up from vllm.utils.mem_utils import format_gib +from vllm.utils.torch_utils import get_dtype_size from vllm.v1.kv_cache_interface import ( ChunkedLocalAttentionSpec, FullAttentionSpec, + HiddenStateCacheSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, @@ -1650,15 +1652,33 @@ def get_kv_cache_groups( _annotate_eagle_groups_deepseek_v4(vllm_config, kv_cache_spec, kv_cache_groups) return kv_cache_groups + # Pull HiddenStateCacheSpec layers out before the general multi-group + # path so they don't affect page-size unification or grouping. + hidden_specs = { + k: v for k, v in kv_cache_spec.items() if isinstance(v, HiddenStateCacheSpec) + } + filtered_spec = { + k: v + for k, v in kv_cache_spec.items() + if not isinstance(v, HiddenStateCacheSpec) + } + # As KVCacheManager can only allocate memory of one size, we need to unify # the page size of the layers. For cases cannot be unified, this function # will raise an error. - kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec) - # Model contains multiple attention types, but KV cache of all layers - # have the same physical memory per block per layer. Split the layers - # into groups with the same number of layers, and thus same total page - # size. - return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) + filtered_spec = unify_kv_cache_spec_page_size(filtered_spec) + groups = _get_kv_cache_groups_uniform_page_size(filtered_spec) + + # Add hidden-state layers back with page aligned to the common page. + if hidden_specs: + common_page = get_uniform_page_size([g.kv_cache_spec for g in groups]) + for name, spec in hidden_specs.items(): + per_token = spec.num_kv_heads * spec.head_size * get_dtype_size(spec.dtype) + new_bs = max(common_page // per_token, 1) + aligned = replace(spec, block_size=new_bs, page_size_padded=common_page) + groups.append(KVCacheGroupSpec([name], aligned)) + + return groups def generate_scheduler_kv_cache_config( diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e8d3a6f75688..7e23f176ec30 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -16,6 +16,7 @@ ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, + HiddenStateCacheSpec, KVCacheSpec, MambaSpec, MLAAttentionSpec, @@ -1143,6 +1144,7 @@ def __init__( FullAttentionSpec: FullAttentionManager, TQFullAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, + HiddenStateCacheSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, SlidingWindowMLASpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index cf50dbff179a..810e4c30f913 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -383,6 +383,13 @@ def merge(cls, specs: list[Self]) -> Self: ) +@dataclass(frozen=True, kw_only=True) +class HiddenStateCacheSpec(MLAAttentionSpec): + """Marker for hidden-state cache layers used by extract_hidden_states.""" + + pass + + @dataclass(frozen=True, kw_only=True) class ChunkedLocalAttentionSpec(AttentionSpec): attention_chunk_size: int diff --git a/vllm/v1/spec_decode/extract_hidden_states.py b/vllm/v1/spec_decode/extract_hidden_states.py index 380836bb3ac9..26ba0352e2a7 100644 --- a/vllm/v1/spec_decode/extract_hidden_states.py +++ b/vllm/v1/spec_decode/extract_hidden_states.py @@ -42,6 +42,7 @@ def __init__(self, vllm_config: VllmConfig, device): self.model: nn.Module | None = None self.attn_layer_names: list[str] = [] self.attn_metadata_builder: AttentionMetadataBuilder | None = None + self.kv_cache_gid: int = -1 # Maximum number of tokens for buffers max_batch_size = vllm_config.scheduler_config.max_num_seqs @@ -374,9 +375,12 @@ def load_model(self, target_model: nn.Module) -> None: ) def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: - """Validate all drafting layers belong to the same KV cache group. - - With exactly one attention layer (asserted in load_model), this is - trivially satisfied. - """ + """Validate all drafting layers belong to the same KV cache group + and record the group index for common_attn_metadata selection.""" assert len(self.attn_layer_names) == 1 + layer = self.attn_layer_names[0] + for gid, group in enumerate(kv_cache_config.kv_cache_groups): + if layer in group.layer_names: + self.kv_cache_gid = gid + return + raise ValueError(f"Cache-only layer {layer!r} not in any KV cache group") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index df061255b3a5..ea0be5343413 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2342,7 +2342,13 @@ def _build_attn_group_metadata( if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance( - self.drafter, (EagleProposer, DFlashProposer, Gemma4Proposer) + self.drafter, + ( + EagleProposer, + DFlashProposer, + Gemma4Proposer, + ExtractHiddenStatesProposer, + ), ): if self.drafter.kv_cache_gid == kv_cache_gid: spec_decode_common_attn_metadata = cm