Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions tests/v1/kv_connector/unit/test_backwards_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,95 @@ def test_signature_detection_with_mocking():
assert isinstance(new_connector, NewStyleTestConnector)
assert new_connector._kv_cache_config is not None
assert new_connector._kv_cache_config == kv_cache_config


# ---------------------------------------------------------------------------
# Regression tests for issue #40690:
# MultiConnector must route child construction through the compat-aware factory
# so that legacy 2-arg connectors work as children, not just at the top level.
# ---------------------------------------------------------------------------

# Dotted import path of this test module. The MultiConnector tests below pass
# it as each child's "kv_connector_module_path" alongside the name of a test
# connector class (OldStyleTestConnector / NewStyleTestConnector).
_THIS_MODULE = "tests.v1.kv_connector.unit.test_backwards_compatibility"


@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_multiconnector_old_style_child_instantiation(role):
    """
    Regression test for #40690: legacy child connector inside MultiConnector.

    A connector that still uses the old 2-arg __init__(self, vllm_config, role)
    must be usable as a MultiConnector child. Before the fix, MultiConnector
    called connector_cls(config, role, kv_cache_config) directly; the extra
    positional argument made the old-style constructor raise TypeError. After
    the fix, children are built through KVConnectorFactory.create_connector,
    which applies the same compat shim as the top-level path.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
        MultiConnector,
    )

    # Single legacy child, resolved by class name from this test module.
    legacy_child = {
        "kv_connector": "OldStyleTestConnector",
        "kv_role": "kv_both",
        "kv_connector_module_path": _THIS_MODULE,
    }
    multi_config = create_vllm_config(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"connectors": [legacy_child]},
    )

    kv_cache_config = create_scheduler(multi_config).kv_cache_config

    # Before the fix this call raised TypeError; with the fix it must succeed.
    multi = KVConnectorFactory.create_connector(multi_config, role, kv_cache_config)

    assert isinstance(multi, MultiConnector)
    children = multi._connectors
    assert len(children) == 1
    only_child = children[0]
    assert isinstance(only_child, OldStyleTestConnector)
    # The 2-arg constructor is never handed kv_cache_config.
    assert only_child._kv_cache_config is None


@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_multiconnector_new_style_child_unaffected(role):
    """
    Guard against regressions for new-style children of MultiConnector.

    A child with the 3-arg constructor must still be built correctly and must
    still receive the kv_cache_config given to the top-level factory call.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
        MultiConnector,
    )

    # Single new-style child, resolved by class name from this test module.
    new_style_child = {
        "kv_connector": "NewStyleTestConnector",
        "kv_role": "kv_both",
        "kv_connector_module_path": _THIS_MODULE,
    }
    multi_config = create_vllm_config(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"connectors": [new_style_child]},
    )

    kv_cache_config = create_scheduler(multi_config).kv_cache_config

    multi = KVConnectorFactory.create_connector(multi_config, role, kv_cache_config)

    assert isinstance(multi, MultiConnector)
    children = multi._connectors
    assert len(children) == 1
    only_child = children[0]
    assert isinstance(only_child, NewStyleTestConnector)
    # The 3-arg constructor is handed kv_cache_config and keeps it unchanged.
    assert only_child._kv_cache_config is not None
    assert only_child._kv_cache_config == kv_cache_config
44 changes: 28 additions & 16 deletions vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,12 @@ def __init__(

self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
for connector_cls, temp_config in self._get_connector_classes_and_configs(
vllm_config
):
self._connectors.append(connector_cls(temp_config, role, kv_cache_config))
for temp_config in self._get_child_configs(vllm_config):
connector = KVConnectorFactory.create_connector(
temp_config, role, kv_cache_config
)
assert isinstance(connector, KVConnectorBase_V1)
self._connectors.append(connector)
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)

# A mapping from request id to the index of the connector chosen to
Expand All @@ -183,31 +185,41 @@ def prefer_cross_layer_blocks(self) -> bool:
return all(c.prefer_cross_layer_blocks for c in self._connectors)

@classmethod
def _get_child_configs(cls, vllm_config: "VllmConfig") -> list["VllmConfig"]:
    """Return one VllmConfig per child connector.

    Each returned config is a shallow copy of ``vllm_config`` whose
    ``kv_transfer_config`` is replaced by a KVTransferConfig built from the
    corresponding entry of ``kv_connector_extra_config["connectors"]``.
    Class resolution is intentionally omitted here so that callers such as
    __init__ can delegate it to KVConnectorFactory.create_connector and
    avoid resolving the class twice.

    A child entry may set its own "engine_id"; otherwise it inherits the
    parent's engine_id.

    Raises:
        AssertionError: if ``vllm_config.kv_transfer_config`` is None or its
            extra config has no "connectors" entry.
    """
    parent_ktc = vllm_config.kv_transfer_config
    assert parent_ktc is not None
    ktcs = parent_ktc.kv_connector_extra_config.get("connectors")
    assert ktcs is not None
    ret: list[VllmConfig] = []
    for ktc in ktcs:
        # Shallow copy: only kv_transfer_config is swapped on the child.
        temp_config = copy.copy(vllm_config)
        engine_id = ktc.get("engine_id", parent_ktc.engine_id)
        # Merge into a single mapping before unpacking. Passing both
        # ``**ktc`` and ``engine_id=`` raises TypeError ("multiple values
        # for keyword argument") whenever the child sets "engine_id".
        temp_config.kv_transfer_config = KVTransferConfig(
            **{**ktc, "engine_id": engine_id}
        )
        ret.append(temp_config)
    return ret

@classmethod
def _get_connector_classes_and_configs(
    cls, vllm_config: "VllmConfig"
) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]:
    """Return a (connector class, child VllmConfig) pair per child connector.

    The child configs come from ``_get_child_configs``; the class for each
    one is looked up with ``KVConnectorFactory.get_connector_class`` using
    that child's ``kv_transfer_config``.
    """
    pairs: list[tuple[type[KVConnectorBaseType], VllmConfig]] = []
    for child_config in cls._get_child_configs(vllm_config):
        # _get_child_configs assigns kv_transfer_config on every config it
        # returns; the type-ignore is carried over from the original call.
        connector_cls = KVConnectorFactory.get_connector_class(
            child_config.kv_transfer_config  # type: ignore[arg-type]
        )
        pairs.append((connector_cls, child_config))
    return pairs

def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
Expand Down
Loading