Skip to content
Merged
13 changes: 13 additions & 0 deletions .buildkite/test_areas/misc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,19 @@ steps:
- pytest -v -s -m 'cpu_test' v1/kv_connector/unit
- pytest -v -s -m 'cpu_test' v1/metrics

# CI step that runs the extract_hidden_states integration tests located under
# v1/kv_connector/extract_hidden_states_integration.
- label: Extract Hidden States Integration
key: extract-hidden-states-integration
timeout_in_minutes: 20
device: h200_18gb
# Paths associated with this step: the extract_hidden_states sources and the
# integration test directory itself.
source_file_dependencies:
- vllm/v1/spec_decode/extract_hidden_states.py
- vllm/model_executor/models/extract_hidden_states.py
- vllm/transformers_utils/configs/extract_hidden_states.py
- tests/v1/kv_connector/extract_hidden_states_integration
commands:
# Select the "spawn" worker start method for the pytest run below.
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s v1/kv_connector/extract_hidden_states_integration

- label: Regression
key: regression
timeout_in_minutes: 20
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def register_predictable_model():


def test_extract_hidden_states_with_predictable_dummy_model(
predictable_llama_config_path, tmp_path
predictable_llama_config_path, tmp_path, monkeypatch
):
"""Comprehensive test using a predictable dummy model with synthetic weights.

Expand All @@ -94,6 +94,12 @@ def test_extract_hidden_states_with_predictable_dummy_model(
3. Layer ordering is preserved correctly (non-sequential layer IDs)
4. Multiple prompts of different lengths produce consistent layer values
"""
# Force fork so the engine worker inherits the autouse fixture's
# ModelRegistry.register_model("PredictableLlamaForCausalLM", ...).
# Spawn (the CI default) starts a fresh Python process that wouldn't
# see the registration.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")

# Test with non-sequential layer ordering to verify correct association
layer_ids = [5, 2, 10]
num_layers = len(layer_ids)
Expand Down Expand Up @@ -153,3 +159,55 @@ def test_extract_hidden_states_with_predictable_dummy_model(
f"but got mean={layer_hidden.mean():.3f}, "
f"min={layer_hidden.min():.3f}, max={layer_hidden.max():.3f}"
)


def test_extract_hidden_states_qwen35_hybrid_smoke(tmp_path):
    """Plumbing check for hidden-state extraction on a Qwen3.5 hybrid model.

    Qwen3.5 mixes mamba and full-attention layers. The model is created with
    ``load_format="dummy"``, so the values themselves are not inspected; the
    test only verifies that every request reports a ``hidden_states_path``
    pointing at an existing file whose ``token_ids`` match the prompt and
    whose ``hidden_states`` tensor has shape
    ``(num_prompt_tokens, len(layer_ids), hidden_size)``.
    """
    layer_ids = [5, 11, 17]
    hidden_size = 1024  # Qwen/Qwen3.5-0.8B hidden_size

    # Speculative-decoding settings selecting the hidden-state extraction
    # method and the layers to capture.
    speculative_config = {
        "method": "extract_hidden_states",
        "num_speculative_tokens": 1,
        "draft_model_config": {
            "hf_config": {"eagle_aux_hidden_state_layer_ids": layer_ids}
        },
    }
    # Connector settings; the pytest temporary directory is used as the
    # shared storage path.
    kv_transfer_config = {
        "kv_connector": "ExampleHiddenStatesConnector",
        "kv_role": "kv_producer",
        "kv_connector_extra_config": {"shared_storage_path": str(tmp_path)},
    }

    llm = LLM(
        model="Qwen/Qwen3.5-0.8B",
        speculative_config=speculative_config,
        kv_transfer_config=kv_transfer_config,
        max_model_len=256,
        enforce_eager=True,
        gpu_memory_utilization=0.4,
        load_format="dummy",
    )

    prompts = ["Hello world", "Test prompt with several tokens"]
    outputs = llm.generate(
        prompts, SamplingParams(max_tokens=1, temperature=0.0)
    )

    # Drop the engine before inspecting the results.
    del llm
    gc.collect()

    assert len(outputs) == len(prompts)
    for request_output in outputs:
        transfer_params = request_output.kv_transfer_params
        assert transfer_params is not None

        saved_path = transfer_params.get("hidden_states_path")
        assert saved_path is not None
        assert os.path.exists(saved_path)

        with safe_open(saved_path, "pt") as tensors:
            saved_token_ids = tensors.get_tensor("token_ids")
            saved_hidden_states = tensors.get_tensor("hidden_states")

        prompt_ids = request_output.prompt_token_ids
        assert torch.equal(saved_token_ids, torch.tensor(prompt_ids))

        expected_shape = (len(prompt_ids), len(layer_ids), hidden_size)
        assert saved_hidden_states.shape == expected_shape
7 changes: 2 additions & 5 deletions tests/v1/kv_connector/unit/test_decode_bench_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch

from vllm import SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole

# ruff: noqa: E501
Expand Down Expand Up @@ -44,11 +43,9 @@ def __init__(self, block_size: int, num_gpu_blocks: int):

# Create vllm config with DecodeBenchConnector
vllm_config = create_vllm_config(
block_size=block_size, max_num_batched_tokens=1000
)
vllm_config.kv_transfer_config = KVTransferConfig(
block_size=block_size,
max_num_batched_tokens=1000,
kv_connector="DecodeBenchConnector",
kv_role="kv_both",
)

self.vllm_config = vllm_config
Expand Down
9 changes: 5 additions & 4 deletions tests/v1/kv_connector/unit/test_kv_connector_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ def _make_empty_scheduler_output():


def test_kv_connector_mixin_clears_metadata():
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "TestExampleConnector"
vllm_config.kv_transfer_config.kv_role = "kv_both"
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"
vllm_config = create_vllm_config(
kv_connector="TestExampleConnector",
kv_role="kv_both",
kv_connector_extra_config={"name": "unit"},
)

# Initialize the global connector instance
kv_cache_config = KVCacheConfig(
Expand Down
57 changes: 24 additions & 33 deletions tests/v1/kv_connector/unit/test_multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,14 @@ def request_finished_all_groups(self, request, block_ids):
@pytest.fixture
def mc() -> MultiConnector:
"""MultiConnector using two mocked connectors"""
vllm_config = create_vllm_config()

mock_connector_config = {
"kv_connector": "MockConnector",
"kv_role": "kv_both",
"kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector",
}

vllm_config.kv_transfer_config = KVTransferConfig(
vllm_config = create_vllm_config(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [mock_connector_config, mock_connector_config],
},
Expand Down Expand Up @@ -403,39 +400,35 @@ def test_multi_connector_handle_preemptions_integration():

try:
# Configure MultiConnector with two TestExampleConnectors
kv_transfer_config = KVTransferConfig(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [
{
"kv_connector": "TestExampleConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_path / "s1"),
"name": "preempt1",
},
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
connectors_extra_config = {
"connectors": [
{
"kv_connector": "TestExampleConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_path / "s1"),
"name": "preempt1",
},
{
"kv_connector": "TestExampleConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_path / "s2"),
"name": "preempt2",
},
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
},
{
"kv_connector": "TestExampleConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_path / "s2"),
"name": "preempt2",
},
]
},
)
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
},
]
}

vllm_config = create_vllm_config(
block_size=16,
max_num_batched_tokens=100,
kv_connector_extra_config=kv_transfer_config.kv_connector_extra_config,
kv_connector="MultiConnector",
kv_connector_extra_config=connectors_extra_config,
)
vllm_config.kv_transfer_config = kv_transfer_config

# Create scheduler - this initializes the MultiConnector with SCHEDULER role
scheduler = create_scheduler(vllm_config, num_blocks=10)
Expand Down Expand Up @@ -971,7 +964,6 @@ def assert_update_connector_output_called(mc: MultiConnector):

def _make_multi_connector(connector_names: list[str]) -> MultiConnector:
"""Build a MultiConnector wrapping the given registered connectors."""
vllm_config = create_vllm_config()
connectors = [
{
"kv_connector": name,
Expand All @@ -980,9 +972,8 @@ def _make_multi_connector(connector_names: list[str]) -> MultiConnector:
}
for name in connector_names
]
vllm_config.kv_transfer_config = KVTransferConfig(
vllm_config = create_vllm_config(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={"connectors": connectors},
)
kv_cache_config = KVCacheConfig(
Expand Down
2 changes: 2 additions & 0 deletions tests/v1/kv_connector/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def create_vllm_config(
attention_backend: str | None = None,
kv_load_failure_policy: Literal["recompute", "fail"] = "fail",
kv_connector: str = "NixlConnector",
kv_connector_module_path: str | None = None,
kv_role: str = "kv_both",
disable_hybrid_kv_cache_manager: bool | None = None,
) -> VllmConfig:
Expand Down Expand Up @@ -130,6 +131,7 @@ def create_vllm_config(
)
kv_transfer_config = KVTransferConfig(
kv_connector=kv_connector,
kv_connector_module_path=kv_connector_module_path,
kv_role=kv_role,
enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {},
Expand Down
53 changes: 39 additions & 14 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,10 @@ def has_blocked_weights():
"the `reasoning_start_str` and `reasoning_end_str`."
)

# Resolve kv_offloading-derived connector name into kv_transfer_config
# before the HMA check below, which inspects the connector class.
self._post_init_kv_transfer_config()

# Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-kv-cache-manager): error if runtime
# disables it
Expand Down Expand Up @@ -1353,18 +1357,42 @@ def has_blocked_weights():
if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to disable HMA, but only if the user didn't express a preference.
if self.kv_transfer_config is not None:
# NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
need_disable_hybrid_kv_cache_manager = True
logger.warning(
"Turning off hybrid kv cache manager because "
"`--kv-transfer-config` is set. This will reduce the "
"performance of vLLM on LLMs with sliding window attention "
"or Mamba attention. If you are a developer of kv connector"
", please consider supporting hybrid kv cache manager for "
"your connector by making sure your connector is a subclass"
" of `SupportsHMA` defined in kv_connector/v1/base.py and"
" use --no-disable-hybrid-kv-cache-manager to start vLLM."
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
supports_hma,
)

connector_cls = KVConnectorFactory.get_connector_class(
self.kv_transfer_config
)
all_support_hma = supports_hma(connector_cls)
# MultiConnector subclasses SupportsHMA; only effectively
# supports HMA when every sub-connector does.
if all_support_hma and connector_cls.__name__ == "MultiConnector":
sub_ktcs = self.kv_transfer_config.kv_connector_extra_config.get(
"connectors", []
)
all_support_hma = all(
supports_hma(
KVConnectorFactory.get_connector_class(
KVTransferConfig(**sub)
)
)
for sub in sub_ktcs
)
if not all_support_hma:
need_disable_hybrid_kv_cache_manager = True
logger.warning(
"Turning off hybrid kv cache manager because "
"connector %s does not subclass `SupportsHMA`. "
"This will reduce performance on models with "
"sliding window or Mamba attention. See "
"kv_connector/v1/base.py for details.",
connector_cls.__name__,
)
Comment on lines +1371 to +1395
Copy link
Copy Markdown
Member

@NickLucche NickLucche May 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we were tracking this change here #41847 and #42024 !

self.scheduler_config.disable_hybrid_kv_cache_manager = (
need_disable_hybrid_kv_cache_manager
)
Expand Down Expand Up @@ -1406,10 +1434,7 @@ def has_blocked_weights():
if "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")

# Handle the KV connector configs
self._post_init_kv_transfer_config()
self._verify_kv_transfer_compat()

# Log the custom passes that are enabled
self.compilation_config.pass_config.log_enabled_passes()

Expand Down
Loading
Loading