Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch

import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.config import KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (
Expand Down Expand Up @@ -2165,3 +2165,29 @@ def test_hma_not_disabled_when_kv_events_enabled():
assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False, (
"kv_events_config must not force-disable the hybrid KV cache manager."
)


def test_hma_not_disabled_for_supported_kv_connector():
    """A connector that advertises HMA support (NixlConnector) must not
    cause the hybrid KV cache manager to be disabled by default."""
    config = VllmConfig(
        kv_transfer_config=KVTransferConfig(
            kv_connector="NixlConnector",
            kv_role="kv_both",
        ),
    )

    assert config.scheduler_config.disable_hybrid_kv_cache_manager is False


def test_hma_disabled_for_unsupported_kv_connector():
    """A connector with no HMA support declaration must force the hybrid
    KV cache manager off by default."""
    config = VllmConfig(
        kv_transfer_config=KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
        ),
    )

    assert config.scheduler_config.disable_hybrid_kv_cache_manager is True
46 changes: 34 additions & 12 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,18 +1288,40 @@ def has_blocked_weights():
if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to disable HMA, but only if the user didn't express a preference.
if self.kv_transfer_config is not None:
# NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
need_disable_hybrid_kv_cache_manager = True
logger.warning(
"Turning off hybrid kv cache manager because "
"`--kv-transfer-config` is set. This will reduce the "
"performance of vLLM on LLMs with sliding window attention "
"or Mamba attention. If you are a developer of kv connector"
", please consider supporting hybrid kv cache manager for "
"your connector by making sure your connector is a subclass"
" of `SupportsHMA` defined in kv_connector/v1/base.py and"
" use --no-disable-hybrid-kv-cache-manager to start vLLM."
)
# NOTE(Kuntai): turn HMA off for connector unless it explicitly
# advertises support.
try:
# Lazy import to avoid circular dependencies.
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import (
supports_hma,
)

connector_cls = KVConnectorFactory.get_connector_class(
self.kv_transfer_config
)
connector_supports_hma = supports_hma(connector_cls)
except Exception:
logger.debug(
"Failed to check whether KV connector supports HMA.",
exc_info=True,
)
connector_supports_hma = False

if not connector_supports_hma:
need_disable_hybrid_kv_cache_manager = True
logger.warning(
"Turning off hybrid kv cache manager because "
"`--kv-transfer-config` is set. This will reduce the "
"performance of vLLM on LLMs with sliding window attention "
"or Mamba attention. If you are a developer of kv connector"
", please consider supporting hybrid kv cache manager for "
"your connector by making sure your connector is a subclass"
" of `SupportsHMA` defined in kv_connector/v1/base.py and"
" use --no-disable-hybrid-kv-cache-manager to start vLLM."
)
self.scheduler_config.disable_hybrid_kv_cache_manager = (
need_disable_hybrid_kv_cache_manager
)
Expand Down
Loading