Merged
174 changes: 172 additions & 2 deletions tests/v1/kv_connector/unit/test_multi_connector.py
@@ -8,13 +8,18 @@
from unittest.mock import MagicMock

import pytest
import torch

from tests.v1.kv_connector.unit.utils import create_vllm_config
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    SupportsHMA,
    supports_hma,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
    MultiConnector,
@@ -83,8 +88,43 @@ def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
        pass


# Register the mock connector
class MockHMAConnector(KVConnectorBase_V1, SupportsHMA):
    """Mock connector that supports HMA for testing."""

    def __new__(cls, *args, **kwargs):
        mock = MagicMock(spec_set=cls)
        return mock

    def start_load_kv(self, forward_context, **kwargs):
        pass

    def wait_for_layer_load(self, layer_name):
        pass

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
        pass

    def wait_for_save(self):
        pass

    def build_connector_meta(self, scheduler_output):
        return None

    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        return (0, False)

    def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
        pass

    def request_finished_all_groups(self, request, block_ids):
        return (False, None)


# Register mock connectors
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
KVConnectorFactory.register_connector(
"MockHMAConnector", __name__, MockHMAConnector.__name__
)
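For context on the registration above: supports_hma() is what lets the tests below tell these two mocks apart. A minimal sketch of such a capability check — the isinstance/issubclass form is an assumption modeled on typical supports_* helpers, not vLLM's actual implementation:

def supports_hma_check(obj) -> bool:
    # Hypothetical sketch (assumption, not vLLM's real helper): accept
    # either a connector class or an instance and test it against the
    # SupportsHMA marker class imported at the top of this file.
    if isinstance(obj, type):
        return issubclass(obj, SupportsHMA)
    return isinstance(obj, SupportsHMA)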


@pytest.fixture
Expand Down Expand Up @@ -920,3 +960,133 @@ def assert_update_connector_output_called(mc: MultiConnector):
    mc.update_connector_output(kv_connector_output)
    assert_update_connector_output_called(mc)
    assert kv_connector_output.kv_connector_worker_meta == mc_worker_meta_01a_01b


def _make_multi_connector(connector_names: list[str]) -> MultiConnector:
    """Build a MultiConnector wrapping the given registered connectors."""
    vllm_config = create_vllm_config()
    connectors = [
        {
            "kv_connector": name,
            "kv_role": "kv_both",
            "kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector",  # noqa: E501
        }
        for name in connector_names
    ]
    vllm_config.kv_transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"connectors": connectors},
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=0,
        kv_cache_tensors=[],
        kv_cache_groups=[],
    )
    return MultiConnector(
        vllm_config=vllm_config,
        role=KVConnectorRole.WORKER,
        kv_cache_config=kv_cache_config,
    )


def test_multi_connector_hma_opt_in():
    """
    MultiConnector currently treats HMA as opt-in: it requires
    --no-disable-hybrid-kv-cache-manager to be set.

    At runtime, _all_support_hma is True only when every sub-connector
    implements SupportsHMA. Test all combinations of HMA / non-HMA
    sub-connectors.
    """

    assert supports_hma(MultiConnector)

    # -- All non-HMA connectors => _all_support_hma is False --
    mc_none = _make_multi_connector(["MockConnector", "MockConnector"])
    assert not supports_hma(mc_none._connectors[0])
    assert not supports_hma(mc_none._connectors[1])
    assert mc_none._all_support_hma is False

    # -- All HMA connectors => _all_support_hma is True --
    mc_all = _make_multi_connector(["MockHMAConnector", "MockHMAConnector"])
    assert supports_hma(mc_all._connectors[0])
    assert supports_hma(mc_all._connectors[1])
    assert mc_all._all_support_hma is True

    # -- Mixed: first HMA, second non-HMA => _all_support_hma is False --
    mc_mixed1 = _make_multi_connector(["MockHMAConnector", "MockConnector"])
    assert supports_hma(mc_mixed1._connectors[0])
    assert not supports_hma(mc_mixed1._connectors[1])
    assert mc_mixed1._all_support_hma is False

    # -- Mixed: first non-HMA, second HMA => _all_support_hma is False --
    mc_mixed2 = _make_multi_connector(["MockConnector", "MockHMAConnector"])
    assert not supports_hma(mc_mixed2._connectors[0])
    assert supports_hma(mc_mixed2._connectors[1])
    assert mc_mixed2._all_support_hma is False


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="Requires GPU to instantiate LLM"
)
def test_multi_connector_mixed_hma_disables_hybrid_kv_cache(monkeypatch):
    """
    When MultiConnector wraps a mix of HMA (NixlConnector) and non-HMA
    (MockConnector) sub-connectors, verify that:
    1. The scheduler's MultiConnector has _all_support_hma == False.
    2. vLLM auto-disables the hybrid KV cache manager (when the user has
       expressed no preference).
    """
    from unittest.mock import patch

    from tests.v1.kv_connector.unit.test_nixl_connector import FakeNixlWrapper

    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    kv_transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "connectors": [
                {
                    "kv_connector": "NixlConnector",
                    "kv_role": "kv_both",
                },
                {
                    "kv_connector": "MockConnector",
                    "kv_role": "kv_both",
                    "kv_connector_module_path": (
                        "tests.v1.kv_connector.unit.test_multi_connector"
                    ),
                },
            ],
        },
    )

    with patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
        FakeNixlWrapper,
    ):
        llm = LLM(
            model="Qwen/Qwen3-0.6B",
            enforce_eager=True,
            gpu_memory_utilization=0.3,
            max_model_len=128,
            max_num_seqs=1,
            max_num_batched_tokens=128,
            kv_transfer_config=kv_transfer_config,
        )
        try:
            # HMA should be auto-disabled when the user has not expressed
            # a preference.
            assert (
                llm.llm_engine.vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
                is True
            )
            # The scheduler-side MultiConnector should detect the mixed
            # HMA support among its sub-connectors.
            scheduler = llm.llm_engine.engine_core.engine_core.scheduler
            mc = scheduler.connector
            assert isinstance(mc, MultiConnector)
            assert mc._all_support_hma is False
        finally:
            llm.llm_engine.engine_core.shutdown()
52 changes: 45 additions & 7 deletions vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

import torch

@@ -18,6 +18,8 @@
    KVConnectorMetadata,
    KVConnectorRole,
    KVConnectorWorkerMetadata,
    SupportsHMA,
    supports_hma,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
    KVConnectorPromMetrics,
@@ -123,7 +125,7 @@ def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
            self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx)


class MultiConnector(KVConnectorBase_V1):
class MultiConnector(KVConnectorBase_V1, SupportsHMA):
Collaborator:

This is problematic.
HMA will be enabled even though you may have sub-connectors which don't support it.
Their update_state_after_alloc and build_connector_meta may fail.

I think we need to replace SupportsHMA with a runtime check like prefers_cross_layer_blocks.

Collaborator Author:

@orozery I am doing a runtime check in this PR with

        self._all_support_hma = all(supports_hma(c) for c in self._connectors)

to fall back for the request_finished method.
I think the easiest solution might be to just ensure HMA is disabled for such cases.

Collaborator:

I think HMA will be enabled if and only if MultiConnector implements SupportsHMA.

Collaborator Author:

Actually let me rephrase that.
If the user explicitly sets --no-disable-hybrid-kv-cache-manager and still packs a sub-connector with no HMA support, then, due to the runtime check above, any HMA-requiring model (Mamba, sliding window, etc.) will crash here:

    def request_finished_all_groups(
        self,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        if not self._all_support_hma:
     =>     assert len(block_ids) == 1, (
                "HMA with multiple kv_cache_groups requires all "
                "sub-connectors to support HMA"
            )

Do you have another scenario in mind, @orozery?

Collaborator:

By default, HMA is enabled unless you're using a connector which does not implement SupportsHMA, right?

If so, this PR will enable HMA by default for the MultiConnector, even if sub-connectors do not support HMA.

Contributor:

IIUC, if a connector is configured, HMA will be disabled by default unless users explicitly set --no-disable-hybrid-kv-cache-manager, as @NickLucche mentioned.
With that being said, I agree that having a connector API for supports_hma() (and changing kv_connector/v1/base.py::supports_hma() to call the connector's API) is a more robust solution.
One quick alternative for now is to change kv_connector/v1/base.py::supports_hma() to special-case the multi connector, iteratively checking each sub-connector, and make the rest of the API changes in a follow-up PR; see the sketch below.
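A rough sketch of that quick alternative (hypothetical, not part of this PR; referencing MultiConnector from base.py would also need care to avoid a circular import):

def supports_hma(connector) -> bool:
    # Hypothetical special path: a MultiConnector supports HMA only if
    # every one of its sub-connectors does; anything else is checked
    # directly against the SupportsHMA marker class.
    if isinstance(connector, MultiConnector):
        return all(supports_hma(c) for c in connector._connectors)
    return isinstance(connector, SupportsHMA)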

Collaborator (orozery, Apr 14, 2026):

So this means that if you use a connector you must set either --disable-hybrid-kv-cache-manager or --no-disable-hybrid-kv-cache-manager, right?

This is weird IMO.
I think that using a connector which supports HMA should work (with HMA enabled) without requiring any additional CLI flags, just as HMA works without explicit enablement when not using a connector.
Collaborator Author:

> So this means that if you use a connector you must set either:

Nope, it's just that HMA is still opt-in for connectors.
This is the logic for HMA on/off (vllm/config/vllm.py, lines 1274-1277 @ 6f786f2):

if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
    # Default to disable HMA, but only if the user didn't express a preference.
    if self.kv_transfer_config is not None:
        # NOTE(Kuntai): turn HMA off for connector unless specifically enabled.

but I don't think that needs to be changed with this PR; imo the scope is narrower here.
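For reference, a hedged sketch of how that default presumably resolves — the helper name and the final assignment are assumptions, since the quoted snippet above is truncated:

def resolve_hma_default(scheduler_config, kv_transfer_config) -> None:
    # Hypothetical sketch: only act when the user expressed no preference.
    if scheduler_config.disable_hybrid_kv_cache_manager is None:
        # Assumption: a configured connector turns HMA off by default;
        # with no connector, HMA stays enabled by default.
        scheduler_config.disable_hybrid_kv_cache_manager = (
            kv_transfer_config is not None
        )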

Collaborator Author:

@orozery @ivanium
I also think a kv_connector/v1/base.py::supports_hma() interface change is in order, but that should be a follow-up PR (which I can still put up in the meantime), where we can also look at changing the auto-on/off behavior.

But I think that right now, given the HMA feature is still experimental/opt-in, the changes proposed here are safe.

Contributor:

I am okay with this PR and with addressing the remaining UX issues in a follow-up PR.

"""
A wrapper for using multiple KVConnectors at the same time.

@@ -166,6 +168,12 @@ def __init__(
            self._connectors.append(connector_cls(temp_config, role, kv_cache_config))
            self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)

        self._all_support_hma = all(supports_hma(c) for c in self._connectors)
Collaborator:

I think we want to add here:

assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager or self._all_support_hma

Collaborator Author:

@orozery I have a model-dependent check here:

        if not self._all_support_hma:
            assert len(block_ids) == 1, (
                "HMA with multiple kv_cache_groups requires all "
                "sub-connectors to support HMA"
            )

which is more confined; e.g., if we're serving Llama 3, it doesn't really matter whether HMA is disabled or not.

Collaborator:

I think we should assert in __init__ at boot instead of waiting for request_finished.
HMA should not be enabled when using a non-supporting connector, even if the model does not utilize HMA (e.g. Llama 3).

Collaborator Author:

I've added the assert:

assert (
    vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
    or self._all_support_hma
), "HMA should not be enabled unless all sub-connectors support it"

        # A mapping from request id to the index of the connector chosen to
        # load the request from (if any).
        self._requests_to_connector: dict[str, int] = {}
@@ -436,15 +444,17 @@ def set_xfer_handshake_metadata(
        for c in self._connectors:
            c.set_xfer_handshake_metadata(metadata)

    def request_finished(
    def _aggregate_request_finished(
        self,
        request: "Request",
        blocks: list[int],
        per_connector_fn: Callable[
            [KVConnectorBase_V1], tuple[bool, dict[str, Any] | None]
        ],
    ) -> tuple[bool, dict[str, Any] | None]:
        async_saves = 0
        kv_txfer_params = None
        for c in self._connectors:
            async_save, txfer_params = c.request_finished(request, blocks)
            async_save, txfer_params = per_connector_fn(c)
            if async_save:
                async_saves += 1
            if txfer_params is not None:
@@ -458,11 +468,39 @@ def request_finished(
        if async_saves > 1:
            self._extra_async_saves[request.request_id] = async_saves - 1

        # Clean up other state for this request.
        self._requests_to_connector.pop(request.request_id, None)

        return async_saves > 0, kv_txfer_params

    def request_finished(
        self,
        request: "Request",
        blocks: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        return self._aggregate_request_finished(
            request,
            lambda c: c.request_finished(request, blocks),
        )

    def request_finished_all_groups(
        self,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        if not self._all_support_hma:
            assert len(block_ids) == 1, (
                "HMA with multiple kv_cache_groups requires all "
                "sub-connectors to support HMA"
            )
            return self.request_finished(request, block_ids[0])

        return self._aggregate_request_finished(
            request,
            lambda c: cast(SupportsHMA, c).request_finished_all_groups(
                request, block_ids
            ),
        )

    def take_events(self) -> Iterable["KVCacheEvent"]:
        for c in self._connectors:
            yield from c.take_events()