Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions tests/v1/kv_connector/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,61 @@ def test_kv_connector(
assert "existing_key" not in kv_connector_extra_config


def _build_config(
    *, kv_connector: str | None, enable_sleep_mode: bool = False
) -> VllmConfig:
    """Construct a minimal VllmConfig and run _verify_kv_transfer_compat.

    Bypasses __init__ via __new__ so no real model is needed (avoids HF
    downloads in CI); model_config is a stub carrying only the
    enable_sleep_mode flag that the check reads.
    """
    from types import SimpleNamespace

    if kv_connector is None:
        transfer_config = None
    else:
        transfer_config = KVTransferConfig(
            kv_connector=kv_connector, kv_role="kv_both"
        )

    config = VllmConfig.__new__(VllmConfig)
    config.kv_transfer_config = transfer_config
    config.model_config = SimpleNamespace(enable_sleep_mode=enable_sleep_mode)
    config._verify_kv_transfer_compat()
    return config


@pytest.mark.parametrize(
    "kv_connector", ["NixlConnector", "MooncakeConnectorV1", "SomeOOTConnector"]
)
def test_kv_connector_rejects_expandable_segments(monkeypatch, kv_connector):
    """Any configured KV connector is rejected under expandable_segments.

    Connectors that pin KV cache memory (e.g. via ibv_reg_mr) break once
    the CUDA VMM allocator remaps the underlying physical pages, and we
    cannot enumerate every connector — including out-of-tree ones — that
    pins memory, so the combination is refused for all of them.
    """
    monkeypatch.setenv("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    with pytest.raises(ValueError, match="expandable_segments"):
        _build_config(kv_connector=kv_connector)


def test_kv_connector_allows_expandable_segments_with_sleep_mode(monkeypatch):
    """With sleep mode enabled the combination is allowed: KV allocations
    go through CuMemAllocator's pool, which auto-disables
    expandable_segments (see #40812), so no error is raised."""
    monkeypatch.setenv("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    _build_config(kv_connector="NixlConnector", enable_sleep_mode=True)


def test_kv_connector_allows_other_alloc_conf(monkeypatch):
    """PYTORCH_CUDA_ALLOC_CONF values that do not enable
    expandable_segments (including an explicit :False) must pass."""
    alloc_conf = "max_split_size_mb:512,expandable_segments:False"
    monkeypatch.setenv("PYTORCH_CUDA_ALLOC_CONF", alloc_conf)
    _build_config(kv_connector="NixlConnector")


def test_no_kv_connector_ignores_expandable_segments(monkeypatch):
    """Without a KV connector there is nothing to protect, so the
    expandable_segments check does not apply."""
    monkeypatch.setenv("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    _build_config(kv_connector=None)


def test_kv_offloading_size_only_uses_native_default():
"""Test that setting only kv_offloading_size enables native offloading."""
vllm_config = VllmConfig(
Expand Down
43 changes: 43 additions & 0 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,48 @@ def _post_init_kv_transfer_config(self) -> None:
# This is the same for all backends
self.kv_transfer_config.kv_role = "kv_both"

def _verify_kv_transfer_compat(self) -> None:
"""Reject configurations that silently corrupt KV transfers."""
if (
self.kv_transfer_config is None
or self.kv_transfer_config.kv_connector is None
):
return

# PyTorch's expandable_segments allocator uses CUDA VMM, which can
# remap a virtual address range to different physical pages over the
# engine's lifetime. KV connectors that pin KV cache memory (e.g.
# NixlConnector via ibv_reg_mr, MooncakeConnector) end up with their
# registrations pointing at stale physical pages after any remap,
# producing RDMA failures like IBV_WC_REM_ACCESS_ERR /
# NIXL_ERR_REMOTE_DISCONNECT at the first inter-node KV transfer.
# We can't enumerate every in-tree and out-of-tree connector that
# pins memory, so we conservatively reject the combination whenever
# any KV connector is configured.
#
# Sleep mode is exempt: CuMemAllocator.use_memory_pool toggles
# expandable_segments off around its pool (see #40812), so the KV
# cache allocated within that context lands on stable physical pages
# even when the env var is set.
if "expandable_segments:True" not in os.environ.get(
"PYTORCH_CUDA_ALLOC_CONF", ""
):
return
if self.model_config is not None and self.model_config.enable_sleep_mode:
return

raise ValueError(
f"KV connector {self.kv_transfer_config.kv_connector} is "
"incompatible with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True "
"unless enable_sleep_mode is also enabled. PyTorch's CUDA VMM "
"allocator can remap KV cache virtual addresses to different "
"physical pages, invalidating any pinned/registered KV memory "
"(e.g. IB memory regions registered by NIXL or Mooncake). Either "
"unset expandable_segments:True or enable sleep mode (which "
"routes KV allocations through CuMemAllocator's pool, where "
"expandable_segments is automatically disabled)."
)

def __post_init__(self):
"""Verify configs are valid & consistent with each other."""

Expand Down Expand Up @@ -1343,6 +1385,7 @@ def has_blocked_weights():

# Handle the KV connector configs
self._post_init_kv_transfer_config()
self._verify_kv_transfer_compat()

# Log the custom passes that are enabled
self.compilation_config.pass_config.log_enabled_passes()
Expand Down
Loading