Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions tests/v1/kv_connector/unit/test_hma_auto_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Regression tests for HMA auto-disable with KV transfer connectors."""

import pytest

from vllm.config import DeviceConfig, KVTransferConfig, SchedulerConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import KVCacheConfig

pytestmark = pytest.mark.cpu_test


@pytest.fixture(autouse=True)
def mock_hybrid_kv_cache_supported(monkeypatch):
    """Patch ``current_platform.support_hybrid_kv_cache`` to return True.

    Applied automatically to every test in this module (``autouse=True``);
    monkeypatch restores the original attribute when each test finishes.
    """

    def _always_supported():
        return True

    monkeypatch.setattr(
        current_platform, "support_hybrid_kv_cache", _always_supported
    )


@pytest.mark.parametrize(
    "kv_transfer_config,expect_disabled",
    [
        # HMA-supporting connector → HMA stays enabled
        pytest.param(
            KVTransferConfig(
                kv_connector="SimpleCPUOffloadConnector",
                kv_role="kv_both",
                kv_connector_extra_config={"cpu_bytes_to_use": 1 << 30},
            ),
            False,
            id="hma_connector",
        ),
        # Non-HMA connector → HMA is auto-disabled
        pytest.param(
            KVTransferConfig(kv_connector="ExampleConnector", kv_role="kv_both"),
            True,
            id="non_hma_connector",
        ),
        # MultiConnector: all HMA children → HMA stays enabled
        pytest.param(
            KVTransferConfig(
                kv_connector="MultiConnector",
                kv_role="kv_both",
                kv_connector_extra_config={
                    "connectors": [
                        {
                            "kv_connector": "SimpleCPUOffloadConnector",
                            "kv_role": "kv_both",
                            "kv_connector_extra_config": {
                                "cpu_bytes_to_use": 1 << 30
                            },
                        },
                        {
                            "kv_connector": "OffloadingConnector",
                            "kv_role": "kv_both",
                            "kv_connector_extra_config": {
                                "cpu_bytes_to_use": 1 << 30
                            },
                        },
                    ]
                },
            ),
            False,
            id="multi_all_hma",
        ),
        # MultiConnector: mixed children → HMA is auto-disabled
        pytest.param(
            KVTransferConfig(
                kv_connector="MultiConnector",
                kv_role="kv_both",
                kv_connector_extra_config={
                    "connectors": [
                        {
                            "kv_connector": "SimpleCPUOffloadConnector",
                            "kv_role": "kv_both",
                            "kv_connector_extra_config": {
                                "cpu_bytes_to_use": 1 << 30
                            },
                        },
                        {"kv_connector": "ExampleConnector", "kv_role": "kv_both"},
                    ]
                },
            ),
            True,
            id="multi_mixed",
        ),
    ],
)
def test_hma_auto_config(kv_transfer_config, expect_disabled):
    """Check the resolved HMA-disable flag for each connector configuration.

    No scheduler config is passed, so the flag is whatever ``VllmConfig``
    derives on its own from the given ``kv_transfer_config``.
    """
    config = VllmConfig(
        device_config=DeviceConfig("cpu"),
        kv_transfer_config=kv_transfer_config,
    )
    hma_disabled = config.scheduler_config.disable_hybrid_kv_cache_manager
    # Identity check: the flag must be exactly True/False, not merely truthy.
    assert hma_disabled is expect_disabled


def test_explicit_hma_with_non_hma_connector_errors_at_factory():
    """Explicitly enabled HMA plus ExampleConnector must be rejected.

    The scheduler config sets ``disable_hybrid_kv_cache_manager=False`` up
    front, and the test expects ``KVConnectorFactory.create_connector`` to
    raise ``ValueError`` for that combination.
    """
    device = DeviceConfig("cpu")
    scheduler_with_hma = SchedulerConfig(
        max_model_len=16,
        is_encoder_decoder=False,
        disable_hybrid_kv_cache_manager=False,
    )
    example_transfer = KVTransferConfig(
        kv_connector="ExampleConnector",
        kv_role="kv_both",
    )
    vllm_config = VllmConfig(
        device_config=device,
        scheduler_config=scheduler_with_hma,
        kv_transfer_config=example_transfer,
    )
    empty_kv_cache = KVCacheConfig(
        num_blocks=0,
        kv_cache_tensors=[],
        kv_cache_groups=[],
    )

    expected_message = "does not support HMA but HMA is enabled"
    with pytest.raises(ValueError, match=expected_message):
        KVConnectorFactory.create_connector(
            vllm_config, KVConnectorRole.SCHEDULER, empty_kv_cache
        )
5 changes: 1 addition & 4 deletions tests/v1/kv_connector/unit/test_multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,11 +1000,8 @@ def _make_multi_connector(connector_names: list[str]) -> MultiConnector:
)


def test_multi_connector_hma_opt_in():
def test_multi_connector_hma_support_detection():
"""
MultiConnector currently assumes HMA is opt-in: it needs
--no-disable-hybrid-kv-cache-manager to be enabled.
At runtime, _all_support_hma is True only when every sub-connector
implements SupportsHMA. Test all combinations of HMA / non-HMA
sub-connectors.
Expand Down
3 changes: 1 addition & 2 deletions tests/v1/kv_connector/unit/test_nixl_connector_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,8 +723,7 @@ def test_has_mamba_init(

block_size = 16
vllm_config = create_vllm_config(block_size=block_size)
# VllmConfig.__post_init__ auto-disables HMA when kv_transfer_config
# is set; override so we can test the scheduler's own derivation.
# Explicitly enable HMA so we can test the scheduler's own derivation.
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
kv_cache_config = make_kv_cache_config(
block_size=block_size,
Expand Down
2 changes: 1 addition & 1 deletion tests/v1/kv_connector/unit/test_offloading_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def test_cpu_offloading(
kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config,
**({"attention_config": {"backend": attn_backend}} if attn_backend else {}),
# HMA models need explicit opt-in when kv_transfer_config is set
# Keep HMA explicitly enabled for HMA model coverage.
**({"disable_hybrid_kv_cache_manager": False} if uses_hma else {}),
**({"enable_prefix_caching": True} if force_prefix_caching else {}),
# ROCm: batch size 1 to reduce variability
Expand Down
42 changes: 12 additions & 30 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ def has_blocked_weights():
# Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-hybrid-kv-cache-manager): error if runtime
# disables it
# - No preference: auto-disable for unsupported features (e.g. kv connector)
# - No preference: auto-disable for unsupported features or connector configs
# - Explicit disable (--disable-hybrid-kv-cache-manager): always respect it
need_disable_hybrid_kv_cache_manager = False
# logger should only print warning message for hybrid models. As we
Expand Down Expand Up @@ -1438,43 +1438,25 @@ def has_blocked_weights():
need_disable_hybrid_kv_cache_manager = True

if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to disable HMA, but only if the user didn't express a preference.
# Auto-disable HMA only when the connector config does not support it.
if self.kv_transfer_config is not None:
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
supports_hma,
)

connector_cls = KVConnectorFactory.get_connector_class(
self.kv_transfer_config
)
all_support_hma = supports_hma(connector_cls)
# MultiConnector subclasses SupportsHMA; only effectively
# supports HMA when every sub-connector does.
if all_support_hma and connector_cls.__name__ == "MultiConnector":
sub_ktcs = self.kv_transfer_config.kv_connector_extra_config.get(
"connectors", []
)
all_support_hma = all(
supports_hma(
KVConnectorFactory.get_connector_class(
KVTransferConfig(**sub)
)
)
for sub in sub_ktcs
)
if not all_support_hma:
if not KVConnectorFactory.supports_hma_config(self.kv_transfer_config):
need_disable_hybrid_kv_cache_manager = True
logger.warning(
"Turning off hybrid kv cache manager because "
"connector %s does not subclass `SupportsHMA`. "
"This will reduce performance on models with "
"sliding window or Mamba attention. See "
"kv_connector/v1/base.py for details.",
connector_cls.__name__,
"`--kv-transfer-config` selects a KV connector that "
"does not support it. Impact: hybrid SSM models "
"(e.g. Jamba, Bamba) require HMA and will fail at "
"startup without it; models with sliding window "
"attention will run with reduced performance. "
"To add HMA support to a KV connector, subclass "
"`SupportsHMA` defined in kv_connector/v1/base.py "
"(for MultiConnector, all child connectors must "
"support HMA)."
)
self.scheduler_config.disable_hybrid_kv_cache_manager = (
need_disable_hybrid_kv_cache_manager
Expand Down
21 changes: 19 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Callable
from typing import TYPE_CHECKING, cast

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase,
KVConnectorBaseType,
Expand All @@ -18,7 +19,6 @@

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.v1.kv_cache_interface import KVCacheConfig

logger = init_logger(__name__)
Expand Down Expand Up @@ -53,7 +53,7 @@ def create_connector(

# check if the connector supports HMA
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
if hma_enabled and not supports_hma(connector_cls):
if hma_enabled and not cls.supports_hma_config(kv_transfer_config):
raise ValueError(
f"Connector {connector_cls.__name__} does not support HMA but "
f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
Expand Down Expand Up @@ -127,6 +127,23 @@ def get_connector_class(
raise ValueError(f"Unsupported connector type: {connector_name}")
return connector_cls

@classmethod
def supports_hma_config(cls, kv_transfer_config: "KVTransferConfig") -> bool:
    """Tell whether the connector setup in ``kv_transfer_config`` supports HMA.

    For ordinary connectors the answer is ``supports_hma`` applied to the
    resolved connector class. A config whose ``kv_connector`` is
    ``"MultiConnector"`` is delegated to
    ``MultiConnector.all_children_support_hma``, because the wrapper class
    implements SupportsHMA while effective support depends on every
    configured child.
    """
    # Resolve the class up front in both branches, so any error raised by
    # get_connector_class surfaces for MultiConnector configs as well.
    resolved_cls = cls.get_connector_class(kv_transfer_config)

    if kv_transfer_config.kv_connector == "MultiConnector":
        # Imported lazily; presumably avoids an import cycle with the
        # multi_connector module (TODO confirm).
        from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
            MultiConnector,
        )

        return MultiConnector.all_children_support_hma(kv_transfer_config)

    return supports_hma(resolved_cls)


# Register various connectors here.
# The registration should not be done in each individual file, as we want to
Expand Down
22 changes: 20 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
KVConnectorRole,
KVConnectorWorkerMetadata,
SupportsHMA,
supports_hma,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
Expand Down Expand Up @@ -151,6 +150,22 @@ def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
return True
return False

@classmethod
def all_children_support_hma(cls, kv_transfer_config: "KVTransferConfig") -> bool:
    """Report whether each child connector named in the config supports HMA.

    Children are read from the ``"connectors"`` entry of
    ``kv_connector_extra_config``. A missing or empty list yields False;
    otherwise the result is True only when
    ``KVConnectorFactory.supports_hma_config`` accepts every child.
    """
    extra_config = kv_transfer_config.kv_connector_extra_config
    child_entries = extra_config.get("connectors", [])
    if not child_entries:
        return False

    def _child_supports_hma(entry) -> bool:
        # The parent's engine_id is only a default: keys in the child entry
        # come later in the merge and therefore take precedence.
        merged = {"engine_id": kv_transfer_config.engine_id, **entry}
        return KVConnectorFactory.supports_hma_config(KVTransferConfig(**merged))

    # all() over a generator stops at the first unsupported child, so later
    # children are never instantiated.
    return all(_child_supports_hma(entry) for entry in child_entries)

def __init__(
self,
vllm_config: "VllmConfig",
Expand All @@ -169,7 +184,10 @@ def __init__(
self._connectors.append(connector_cls(temp_config, role, kv_cache_config))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)

self._all_support_hma = all(supports_hma(c) for c in self._connectors)
assert vllm_config.kv_transfer_config is not None
self._all_support_hma = MultiConnector.all_children_support_hma(
vllm_config.kv_transfer_config
)
assert (
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
or self._all_support_hma
Expand Down
Loading