diff --git a/tests/v1/kv_connector/unit/test_backwards_compatibility.py b/tests/v1/kv_connector/unit/test_backwards_compatibility.py index da6a5aadbc6d..89a004bcf232 100644 --- a/tests/v1/kv_connector/unit/test_backwards_compatibility.py +++ b/tests/v1/kv_connector/unit/test_backwards_compatibility.py @@ -273,3 +273,95 @@ def test_signature_detection_with_mocking(): assert isinstance(new_connector, NewStyleTestConnector) assert new_connector._kv_cache_config is not None assert new_connector._kv_cache_config == kv_cache_config + + +# --------------------------------------------------------------------------- +# Regression tests for issue #40690: +# MultiConnector must route child construction through the compat-aware factory +# so that legacy 2-arg connectors work as children, not just at the top level. +# --------------------------------------------------------------------------- + +_THIS_MODULE = "tests.v1.kv_connector.unit.test_backwards_compatibility" + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_multiconnector_old_style_child_instantiation(role): + """ + Regression test for #40690. + + A legacy connector with the old 2-arg __init__(self, vllm_config, role) + must work as a child of MultiConnector. Before the fix, MultiConnector + directly called connector_cls(config, role, kv_cache_config), which passed + a third positional argument that the old-style connector doesn't accept, + raising TypeError. After the fix, child construction goes through + KVConnectorFactory.create_connector, which applies the same compat shim + that the top-level path uses. 
+ """ + from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( + MultiConnector, + ) + + vllm_config = create_vllm_config( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [ + { + "kv_connector": "OldStyleTestConnector", + "kv_role": "kv_both", + "kv_connector_module_path": _THIS_MODULE, + } + ] + }, + ) + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + # Raised TypeError before the fix; must succeed once the fix is applied. + connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config) + + assert isinstance(connector, MultiConnector) + assert len(connector._connectors) == 1 + child = connector._connectors[0] + assert isinstance(child, OldStyleTestConnector) + # Old-style connector receives no kv_cache_config + assert child._kv_cache_config is None + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_multiconnector_new_style_child_unaffected(role): + """ + Ensure that the fix does not break new-style 3-arg children inside + MultiConnector. 
+ """ + from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( + MultiConnector, + ) + + vllm_config = create_vllm_config( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [ + { + "kv_connector": "NewStyleTestConnector", + "kv_role": "kv_both", + "kv_connector_module_path": _THIS_MODULE, + } + ] + }, + ) + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config) + + assert isinstance(connector, MultiConnector) + assert len(connector._connectors) == 1 + child = connector._connectors[0] + assert isinstance(child, NewStyleTestConnector) + # New-style connector receives kv_cache_config + assert child._kv_cache_config is not None + assert child._kv_cache_config == kv_cache_config diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 3888d2e0f44c..8bc120ac218e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -160,10 +160,12 @@ def __init__( self._connectors: list[KVConnectorBase_V1] = [] self._ktc_kv_transfer_config = [] - for connector_cls, temp_config in self._get_connector_classes_and_configs( - vllm_config - ): - self._connectors.append(connector_cls(temp_config, role, kv_cache_config)) + for temp_config in self._get_child_configs(vllm_config): + connector = KVConnectorFactory.create_connector( + temp_config, role, kv_cache_config + ) + assert isinstance(connector, KVConnectorBase_V1) + self._connectors.append(connector) self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) # A mapping from request id to the index of the connector chosen to @@ -183,31 +185,41 @@ def prefer_cross_layer_blocks(self) -> bool: return all(c.prefer_cross_layer_blocks for c in self._connectors) 
@classmethod - def _get_connector_classes_and_configs( - cls, vllm_config: "VllmConfig" - ) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]: + def _get_child_configs(cls, vllm_config: "VllmConfig") -> list["VllmConfig"]: + """Return one VllmConfig per child connector, with kv_transfer_config + set to the child's KVTransferConfig. Class resolution is intentionally + omitted here so that callers such as __init__ can delegate it to + KVConnectorFactory.create_connector and avoid resolving the class twice. + """ assert vllm_config.kv_transfer_config is not None ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "connectors" ) assert ktcs is not None - ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = [] + ret: list[VllmConfig] = [] for ktc in ktcs: temp_config = copy.copy(vllm_config) engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id) temp_config.kv_transfer_config = KVTransferConfig( **ktc, engine_id=engine_id ) - ret.append( - ( - KVConnectorFactory.get_connector_class( - temp_config.kv_transfer_config - ), - temp_config, - ) - ) + ret.append(temp_config) return ret + @classmethod + def _get_connector_classes_and_configs( + cls, vllm_config: "VllmConfig" + ) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]: + return [ + ( + KVConnectorFactory.get_connector_class( + temp_config.kv_transfer_config # type: ignore[arg-type] + ), + temp_config, + ) + for temp_config in cls._get_child_configs(vllm_config) + ] + def register_cross_layers_kv_cache( self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] ):