From 82629e1f628dfdea9d96e554d0d5551440b84f53 Mon Sep 17 00:00:00 2001 From: zixi-qi Date: Tue, 26 May 2026 05:00:08 +0000 Subject: [PATCH 1/5] [KVConnector] Foundation: PP-aware handshake aggregation and intermediate-PP output plumbing Co-authored-by: Claude Signed-off-by: zixi-qi --- .../unit/test_handshake_pp_aggregation.py | 220 ++++++++++++++++++ .../unit/test_pp_intermediate_output.py | 137 +++++++++++ .../unit/test_transfer_topology_sharded.py | 146 ++++++++++++ .../kv_transfer/kv_connector/utils.py | 52 +++-- .../kv_transfer/kv_connector/v1/base.py | 33 +++ .../kv_connector/v1/multi_connector.py | 23 +- vllm/v1/engine/core.py | 24 +- vllm/v1/executor/abstract.py | 2 +- vllm/v1/worker/gpu/model_runner.py | 31 ++- vllm/v1/worker/gpu_worker.py | 17 +- 10 files changed, 657 insertions(+), 28 deletions(-) create mode 100644 tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py create mode 100644 tests/v1/kv_connector/unit/test_pp_intermediate_output.py create mode 100644 tests/v1/kv_connector/unit/test_transfer_topology_sharded.py diff --git a/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py b/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py new file mode 100644 index 000000000000..68f96b35038e --- /dev/null +++ b/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py @@ -0,0 +1,220 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace +from typing import Any + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorHandshakeMetadata, + SupportsPP, +) +from vllm.v1.engine import core as engine_core_module + +pytestmark = pytest.mark.cpu_test + + +class _Metadata(KVConnectorHandshakeMetadata): + pass + + +class _FakeExecutor: + handshake_metadata_src: ( + list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None + ) + last_instance: "_FakeExecutor | None" = None + 
+ def __init__( + self, + vllm_config: Any, + ) -> None: + del vllm_config + self.handshake_metadata = self.handshake_metadata_src + self.handshake_calls = 0 + self.max_concurrent_batches = 1 + _FakeExecutor.last_instance = self + + def get_kv_connector_handshake_metadata( + self, + ) -> list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None: + self.handshake_calls += 1 + return self.handshake_metadata + + def init_kv_output_aggregator(self, connector: KVConnectorBase_V1) -> None: + pass + + +def _run_engine_core_handshake( + monkeypatch: pytest.MonkeyPatch, + connector: KVConnectorBase_V1, + *, + handshake_metadata: ( + list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None + ), +) -> _FakeExecutor: + class _FakeScheduler: + def __init__(self, **kwargs: Any) -> None: + self.connector = connector + + def get_kv_connector(self) -> KVConnectorBase_V1: + return connector + + _FakeExecutor.handshake_metadata_src = handshake_metadata + _FakeExecutor.last_instance = None + + monkeypatch.setattr("vllm.plugins.load_general_plugins", lambda: None) + monkeypatch.setattr( + engine_core_module.EngineCore, + "_initialize_kv_caches", + lambda self, vllm_config: SimpleNamespace(kv_cache_groups=[object()]), + ) + monkeypatch.setattr( + engine_core_module, + "StructuredOutputManager", + lambda vllm_config: object(), + ) + monkeypatch.setattr( + engine_core_module, + "resolve_kv_cache_block_sizes", + lambda kv_cache_config, vllm_config: (16, 16), + ) + monkeypatch.setattr( + engine_core_module, + "MULTIMODAL_REGISTRY", + SimpleNamespace(engine_receiver_cache_from_config=lambda vllm_config: None), + ) + monkeypatch.setattr(engine_core_module, "freeze_gc_heap", lambda: None) + monkeypatch.setattr( + engine_core_module, "maybe_attach_gc_debug_callback", lambda: None + ) + monkeypatch.setattr(engine_core_module, "enable_envs_cache", lambda: None) + monkeypatch.setattr(engine_core_module, "get_hash_fn_by_name", lambda name: None) + 
monkeypatch.setattr(engine_core_module, "init_none_hash", lambda hash_fn: None) + monkeypatch.setattr( + engine_core_module, "get_request_block_hasher", lambda *args: None + ) + + vllm_config = SimpleNamespace( + parallel_config=SimpleNamespace(data_parallel_rank_local=0), + scheduler_config=SimpleNamespace( + get_scheduler_cls=lambda: _FakeScheduler, + enable_chunked_prefill=False, + async_scheduling=False, + ), + speculative_config=None, + ec_transfer_config=None, + model_config=SimpleNamespace(runner_type="generate"), + cache_config=SimpleNamespace( + enable_prefix_caching=False, + prefix_caching_hash_algo="builtin", + ), + ) + + engine_core_module.EngineCore(vllm_config, _FakeExecutor, log_stats=False) + assert _FakeExecutor.last_instance is not None + return _FakeExecutor.last_instance + + +class _LegacyConnector(KVConnectorBase_V1): + def __init__(self) -> None: + self.legacy_metadata: dict[int, KVConnectorHandshakeMetadata] | None = None + + def start_load_kv(self, forward_context: Any, **kwargs: Any) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: Any, + attn_metadata: Any, + **kwargs: Any, + ) -> None: + pass + + def wait_for_save(self) -> None: + pass + + def get_num_new_matched_tokens( + self, request: Any, num_computed_tokens: int + ) -> tuple[int | None, bool]: + return 0, False + + def update_state_after_alloc( + self, request: Any, blocks: Any, num_external_tokens: int + ) -> None: + pass + + def build_connector_meta(self, scheduler_output: Any) -> Any: + raise NotImplementedError + + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata] + ) -> None: + self.legacy_metadata = metadata + + +class _PPAwareConnector(_LegacyConnector, SupportsPP): + def __init__(self) -> None: + super().__init__() + self.pp_aware_metadata: ( + dict[tuple[int, int], KVConnectorHandshakeMetadata] | None + ) = None + + def 
set_xfer_handshake_metadata_pp_aware( + self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata] + ) -> None: + self.pp_aware_metadata = metadata + + +def test_engine_unwraps_handshake_metadata_for_legacy_connector( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Engine core always asks workers for `(pp_rank, tp_rank)`-keyed metadata, + then unwraps to `{tp_rank: metadata}` for a connector that has not opted + into PP-aware handshake. Entries with `pp_rank != 0` are dropped because + the legacy connector cannot consume them.""" + metadata_0 = _Metadata() + metadata_1 = _Metadata() + dropped = _Metadata() + connector = _LegacyConnector() + + executor = _run_engine_core_handshake( + monkeypatch, + connector, + handshake_metadata=[ + {(0, 0): metadata_0}, + None, + {(0, 1): metadata_1, (1, 0): dropped}, + ], + ) + + assert executor.handshake_calls == 1 + assert connector.legacy_metadata == {0: metadata_0, 1: metadata_1} + + +def test_engine_passes_handshake_metadata_through_for_pp_aware_connector( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A PP-aware connector receives the full `(pp_rank, tp_rank)`-keyed dict + unchanged.""" + metadata_0 = _Metadata() + metadata_1 = _Metadata() + connector = _PPAwareConnector() + + executor = _run_engine_core_handshake( + monkeypatch, + connector, + handshake_metadata=[{(0, 0): metadata_0}, {(1, 0): metadata_1}], + ) + + assert executor.handshake_calls == 1 + assert connector.legacy_metadata is None + assert connector.pp_aware_metadata == { + (0, 0): metadata_0, + (1, 0): metadata_1, + } diff --git a/tests/v1/kv_connector/unit/test_pp_intermediate_output.py b/tests/v1/kv_connector/unit/test_pp_intermediate_output.py new file mode 100644 index 000000000000..5d858e151746 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_pp_intermediate_output.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace 
+from typing import Any, cast + +import pytest + +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput +from vllm.v1.worker.gpu import model_runner as model_runner_module +from vllm.v1.worker.gpu.model_runner import ExecuteModelState, GPUModelRunner + +pytestmark = pytest.mark.cpu_test + + +def _make_non_last_runner( + kv_connector_output: KVConnectorOutput | None, +) -> GPUModelRunner: + runner = GPUModelRunner.__new__(GPUModelRunner) + runner.execute_model_state = ExecuteModelState( + input_batch=cast(Any, SimpleNamespace(num_reqs=1)), + attn_metadata=None, + slot_mappings_by_layer=None, + hidden_states=None, + aux_hidden_states=None, + finished_req_ids=set(), + ) + runner.kv_connector_output = kv_connector_output + runner.is_last_pp_rank = False + runner.num_speculative_steps = 0 + runner.eplb = SimpleNamespace(step=lambda **_: None) + return runner + + +def _patch_non_last_pp_side_effects(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + model_runner_module, + "pp_receive", + lambda num_reqs, max_sample_len: ([], 0, 0), + ) + monkeypatch.setattr( + GPUModelRunner, + "postprocess", + lambda self, input_batch, sampled, num_sampled, num_rejected: None, + ) + + +def test_non_last_pp_without_kv_output_returns_empty_model_runner_output( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_non_last_pp_side_effects(monkeypatch) + runner = _make_non_last_runner(None) + + output = runner.sample_tokens(None) + + assert output is EMPTY_MODEL_RUNNER_OUTPUT + + +def test_non_last_pp_with_kv_output_returns_copy_carrying_kv_output( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_non_last_pp_side_effects(monkeypatch) + kv_connector_output = KVConnectorOutput(finished_sending={"req-0"}) + runner = _make_non_last_runner(kv_connector_output) + + output = runner.sample_tokens(None) + + assert output is not EMPTY_MODEL_RUNNER_OUTPUT + assert output is not None + assert output.kv_connector_output is kv_connector_output + assert 
EMPTY_MODEL_RUNNER_OUTPUT.kv_connector_output is None + + +def test_last_pp_rank_still_returns_regular_model_runner_output( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _FakeAsyncOutput: + def __init__(self, model_runner_output, *args, **kwargs) -> None: + self.model_runner_output = model_runner_output + + def get_output(self): + return self.model_runner_output + + monkeypatch.setattr(model_runner_module, "AsyncOutput", _FakeAsyncOutput) + monkeypatch.setattr( + GPUModelRunner, + "sample", + lambda self, hidden_states, input_batch, grammar_output: ( + SimpleNamespace(sampled_token_ids=[[1]]), + 1, + 0, + ), + ) + monkeypatch.setattr( + GPUModelRunner, + "postprocess", + lambda self, input_batch, sampled, num_sampled, num_rejected: None, + ) + + kv_connector_output = KVConnectorOutput(finished_recving={"req-0"}) + runner = GPUModelRunner.__new__(GPUModelRunner) + runner.execute_model_state = ExecuteModelState( + input_batch=cast(Any, SimpleNamespace(num_reqs=1, req_ids=["req-0"])), + attn_metadata=None, + slot_mappings_by_layer=None, + hidden_states=cast(Any, object()), + aux_hidden_states=None, + finished_req_ids=set(), + ) + runner.kv_connector = SimpleNamespace( + post_forward=lambda finished_req_ids: kv_connector_output, + ) + runner.kv_connector_output = None + runner.is_last_pp_rank = True + runner.use_pp = False + runner.use_async_scheduling = False + runner.speculator = None + runner.main_stream = None + runner.output_copy_stream = None + runner.eplb = SimpleNamespace(step=lambda **_: None) + runner.model = SimpleNamespace(compute_logits=lambda *args, **kwargs: None) + runner.prompt_logprobs_worker = SimpleNamespace( + compute_prompt_logprobs=lambda *args, **kwargs: {} + ) + runner.req_states = SimpleNamespace( + all_token_ids=SimpleNamespace(gpu=None), + num_computed_tokens=SimpleNamespace(gpu=None), + prompt_len=SimpleNamespace(np=None), + prefill_len=SimpleNamespace(np=None), + num_computed_prefill_tokens=None, + ) + + output = 
runner.sample_tokens(None) + + assert output is not None + assert output.req_ids == ["req-0"] + assert output.kv_connector_output is kv_connector_output diff --git a/tests/v1/kv_connector/unit/test_transfer_topology_sharded.py b/tests/v1/kv_connector/unit/test_transfer_topology_sharded.py new file mode 100644 index 000000000000..ac00bb481284 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_transfer_topology_sharded.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.utils import ( + EngineTransferInfo, + TransferTopology, +) + +pytestmark = pytest.mark.cpu_test + + +class _FakeAttentionBackend: + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, int, int, int, int]: + return (2, num_blocks, num_kv_heads, block_size, head_size) + + +def _make_topology( + *, + tp_rank: int = 1, + tp_size: int = 4, + total_num_kv_heads: int = 8, +) -> TransferTopology: + return TransferTopology( + tp_rank=tp_rank, + tp_size=tp_size, + block_size=16, + engine_id="local-engine", + is_mla=False, + is_mamba=False, + total_num_kv_heads=total_num_kv_heads, + attn_backends=[_FakeAttentionBackend], + ) + + +def test_legacy_register_remote_engine_uses_pp_rank_zero() -> None: + topology = _make_topology() + info = EngineTransferInfo( + remote_tp_size=2, + remote_block_len=1024, + remote_block_size=16, + remote_physical_blocks_per_logical=1, + ) + + registered = topology.register_remote_engine("remote-engine", info) + + assert registered == info + assert registered.remote_pp_rank == 0 + assert topology.get_engine_info("remote-engine") == info + assert topology._engines[("remote-engine", 0)] == info + assert topology.target_remote_ranks("remote-engine") == [0] + + +def test_register_remote_engine_stores_pp_ranks_separately() -> None: + topology = 
_make_topology(tp_rank=0, tp_size=2) + + info_0 = EngineTransferInfo( + remote_tp_size=2, + remote_block_len=1024, + remote_block_size=16, + remote_physical_blocks_per_logical=1, + remote_pp_rank=0, + start_layer=0, + end_layer=16, + ) + info_1 = EngineTransferInfo( + remote_tp_size=1, + remote_block_len=512, + remote_block_size=8, + remote_physical_blocks_per_logical=2, + remote_pp_rank=1, + start_layer=16, + end_layer=32, + ) + + registered_0 = topology.register_remote_engine("remote-engine", info_0) + registered_1 = topology.register_remote_engine("remote-engine", info_1) + + assert registered_0 == info_0 + assert registered_1 == info_1 + assert topology.get_engine_info("remote-engine") == info_0 + assert topology.get_engine_info("remote-engine", 0) == info_0 + assert topology.get_engine_info("remote-engine", 1) == info_1 + assert set(topology._engines) == { + ("remote-engine", 0), + ("remote-engine", 1), + } + + +def test_helpers_use_requested_pp_rank() -> None: + topology = _make_topology(tp_rank=1, tp_size=2, total_num_kv_heads=2) + topology.register_remote_engine( + "remote-engine", + EngineTransferInfo( + remote_tp_size=1, + remote_block_len=1024, + remote_block_size=16, + remote_physical_blocks_per_logical=1, + remote_pp_rank=0, + start_layer=0, + end_layer=8, + ), + ) + topology.register_remote_engine( + "remote-engine", + EngineTransferInfo( + remote_tp_size=4, + remote_block_len=1024, + remote_block_size=16, + remote_physical_blocks_per_logical=1, + remote_pp_rank=1, + start_layer=8, + end_layer=16, + ), + ) + + assert not topology.is_kv_replicated("remote-engine", 0) + assert topology.is_kv_replicated("remote-engine", 1) + assert topology.replicates_kv_cache("remote-engine", 1) + assert topology.target_remote_ranks("remote-engine", 0) == [0] + assert topology.target_remote_ranks("remote-engine", 1) == [2, 3] + assert "remote_pp=1" in topology.describe("remote-engine", 1) + + +def test_engine_info_fields_have_backward_compatible_defaults() -> None: + 
topology = _make_topology() + info = EngineTransferInfo( + remote_tp_size=2, + remote_block_len=1024, + remote_block_size=16, + remote_physical_blocks_per_logical=1, + ) + + registered = topology.register_remote_engine("remote-engine", info) + + assert topology.get_engine_info("remote-engine") == registered + assert registered.remote_pp_rank == 0 + assert registered.start_layer == 0 + assert registered.end_layer == 0 diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index d7a595716f08..82a4f9ac724e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -371,7 +371,7 @@ def get_current_attn_backend( class EngineTransferInfo: """Common per-remote-engine transfer state, computed at handshake. - Stored per ``engine_id`` inside ``TransferTopology._engines``. + Stored per ``(engine_id, pp_rank)`` inside ``TransferTopology._engines``. """ remote_tp_size: int @@ -385,6 +385,15 @@ class EngineTransferInfo: remote_physical_blocks_per_logical: int """Physical blocks per logical block.""" + remote_pp_rank: int = 0 + """Remote producer PP rank for this engine.""" + + start_layer: int = 0 + """Global index of the first layer owned by this PP rank.""" + + end_layer: int = 0 + """Exclusive global index after the last layer owned by this PP rank.""" + # ---- Transfer topology ---- @@ -406,7 +415,7 @@ class TransferTopology: def __post_init__(self): self.local_physical_heads = max(1, self.total_num_kv_heads // self.tp_size) - self._engines: dict[EngineId, EngineTransferInfo] = {} + self._engines: dict[tuple[EngineId, int], EngineTransferInfo] = {} # Figure out whether the first dimension of the cache is K/V # or num_blocks. @@ -464,13 +473,16 @@ def register_remote_engine( f"Cannot register local engine {self.engine_id} as remote. " f"Local identity is set via __init__ params." 
) - if remote_engine_id in self._engines: - return self._engines[remote_engine_id] - self._engines[remote_engine_id] = info + engine_key = (remote_engine_id, info.remote_pp_rank) + if engine_key in self._engines: + return self._engines[engine_key] + self._engines[engine_key] = info return info - def get_engine_info(self, remote_engine_id: EngineId) -> EngineTransferInfo: - return self._engines[remote_engine_id] + def get_engine_info( + self, remote_engine_id: EngineId, remote_pp_rank: int = 0 + ) -> EngineTransferInfo: + return self._engines[(remote_engine_id, remote_pp_rank)] # ============================================================ # Layout properties @@ -531,15 +543,22 @@ def block_size_ratio(self, remote_block_size: int) -> int: ) return self.block_size // remote_block_size - def is_kv_replicated(self, remote_engine_id: EngineId) -> bool: + def is_kv_replicated( + self, remote_engine_id: EngineId, remote_pp_rank: int = 0 + ) -> bool: """Whether the KV cache is replicated across TP workers due to the number of TP workers being greater than the number of KV heads. """ - return self._engines[remote_engine_id].remote_tp_size > self.total_num_kv_heads + return ( + self._engines[(remote_engine_id, remote_pp_rank)].remote_tp_size + > self.total_num_kv_heads + ) - def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: + def replicates_kv_cache( + self, remote_engine_id: EngineId, remote_pp_rank: int = 0 + ) -> bool: # MLA is always replicated as the hidden dim can't be split. 
- return self.is_mla or self.is_kv_replicated(remote_engine_id) + return self.is_mla or self.is_kv_replicated(remote_engine_id, remote_pp_rank) @property def local_replicates_kv_cache(self) -> bool: @@ -558,12 +577,14 @@ def handshake_target_ranks(self, remote_tp_size: int) -> list[int]: abs_ratio = -tp_ratio return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)] - def target_remote_ranks(self, remote_engine_id: EngineId) -> list[int]: + def target_remote_ranks( + self, remote_engine_id: EngineId, remote_pp_rank: int = 0 + ) -> list[int]: """Get the remote TP rank(s) that the current local TP rank will read from. When remote tp_size > local tp_size, reads from multiple remote ranks. """ - info = self._engines[remote_engine_id] + info = self._engines[(remote_engine_id, remote_pp_rank)] tp_ratio = self.tp_ratio(info.remote_tp_size) if tp_ratio > 0: return [self.tp_rank // tp_ratio] @@ -596,15 +617,16 @@ def get_transfer_cache_regions( # Regular case: backends like FA register K/V in separate regions return cache if self.split_k_and_v else [cache] - def describe(self, remote_engine_id: EngineId) -> str: + def describe(self, remote_engine_id: EngineId, remote_pp_rank: int = 0) -> str: """One-line summary of transfer config for logging.""" - info = self._engines[remote_engine_id] + info = self._engines[(remote_engine_id, remote_pp_rank)] return ( f"TransferTopology(" f"tp_ratio={self.tp_ratio(info.remote_tp_size)}, " f"num_kv_heads={self.total_num_kv_heads if not self.is_mla else 1}, " f"local_tp={self.tp_size}, " f"remote_tp={info.remote_tp_size}, " + f"remote_pp={remote_pp_rank}, " f"local_rank={self.tp_rank}, " f"remote_block_len={info.remote_block_len})" ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index fb5658da887a..cfb633171b95 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -121,6 +121,39 @@ def 
supports_hma(connector: Any) -> bool: return isinstance(connector, SupportsHMA) +class SupportsPP(ABC): + """ + The class that indicates the corresponding connector supports + pipeline-parallel (PP) disaggregated serving. + + Connectors that inherit from this class receive KV connector handshake + metadata keyed by ``(pp_rank, tp_rank)`` instead of just ``tp_rank``, so + they can track per-PP-rank remote state. + """ + + @abstractmethod + def set_xfer_handshake_metadata_pp_aware( + self, metadata: dict[tuple[int, int], "KVConnectorHandshakeMetadata"] + ) -> None: + """ + Set PP-aware KV connector handshake metadata for this connector. + + NOTE: This function is only supported by connectors that support PP + disaggregation (inherit from ``SupportsPP``). Connectors that do not + inherit from ``SupportsPP`` keep receiving ``{tp_rank: metadata}`` + via ``set_xfer_handshake_metadata``; engine core unwraps the + tuple-keyed dict at the dispatch site. + """ + raise NotImplementedError + + +def supports_pp(connector: Any) -> bool: + if isinstance(connector, type): + return issubclass(connector, SupportsPP) + else: + return isinstance(connector, SupportsPP) + + class KVConnectorRole(enum.Enum): # Connector running in the scheduler process SCHEDULER = 0 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 73418104bea6..a0316e081e06 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -19,6 +19,8 @@ KVConnectorRole, KVConnectorWorkerMetadata, SupportsHMA, + SupportsPP, + supports_pp, ) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, @@ -125,7 +127,7 @@ def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx) -class MultiConnector(KVConnectorBase_V1, 
SupportsHMA): +class MultiConnector(KVConnectorBase_V1, SupportsHMA, SupportsPP): """ A wrapper for using multiple KVConnectors at the same time. @@ -471,6 +473,25 @@ def set_xfer_handshake_metadata( for c in self._connectors: c.set_xfer_handshake_metadata(metadata) + def set_xfer_handshake_metadata_pp_aware( + self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata] + ) -> None: + """ + Set PP-aware KV connector handshake metadata for all sub-connectors. + Non-``SupportsPP`` children receive the ``{tp_rank: metadata}`` shape + via a ``(0, tp_rank)`` unwrap. + """ + for c in self._connectors: + if supports_pp(c): + cast(SupportsPP, c).set_xfer_handshake_metadata_pp_aware(metadata) + else: + non_pp_metadata = { + tp_rank: meta + for (pp_rank, tp_rank), meta in metadata.items() + if pp_rank == 0 + } + c.set_xfer_handshake_metadata(non_pp_metadata) + def _aggregate_request_finished( self, request: "Request", diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c21a4de5d309..05697a3991cc 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -26,6 +26,10 @@ cleanup_dist_env_and_memory, stateless_destroy_torch_distributed_process_group, ) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + SupportsPP, + supports_pp, +) from vllm.envs import enable_envs_cache from vllm.logger import init_logger from vllm.logging_utils.dump_input import dump_engine_exception @@ -176,13 +180,27 @@ def __init__( if xfer_handshake_metadata: # xfer_handshake_metadata is list of dicts from workers - # Each dict already has structure {tp_rank: metadata} + # Each dict already has structure {(pp_rank, tp_rank): metadata} # Merge all worker dicts into a single dict - content: dict[int, Any] = {} + content: dict[tuple[int, int], Any] = {} for worker_dict in xfer_handshake_metadata: if worker_dict is not None: content.update(worker_dict) - kv_connector.set_xfer_handshake_metadata(content) + + if supports_pp(kv_connector): + cast(SupportsPP, 
kv_connector).set_xfer_handshake_metadata_pp_aware( + content + ) + else: + # Unwrap (0, tp_rank) entries to {tp_rank: metadata} for + # non-PP connectors. Entries with pp_rank != 0 cannot be + # consumed by a non-PP connector and are dropped. + non_pp_content: dict[int, Any] = { + tp_rank: meta + for (pp_rank, tp_rank), meta in content.items() + if pp_rank == 0 + } + kv_connector.set_xfer_handshake_metadata(non_pp_content) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index e68c0283f579..88a3a253d4a1 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -203,7 +203,7 @@ def collective_rpc( def get_kv_connector_handshake_metadata( self, - ) -> list[dict[int, KVConnectorHandshakeMetadata]]: + ) -> list[dict[tuple[int, int], KVConnectorHandshakeMetadata]]: return self.collective_rpc("get_kv_connector_handshake_metadata") @overload diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 5cba66e5c9f9..0e618ae71142 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -17,6 +17,7 @@ instead of embedding feature-specific logic directly. 
""" +import copy import functools import gc import time @@ -50,7 +51,11 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec -from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + DraftTokenIds, + ModelRunnerOutput, +) from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput from vllm.v1.worker.gpu.attn_utils import ( @@ -244,6 +249,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # For transferring state from execute_model to subsequent sample_tokens call. self.execute_model_state: ExecuteModelState | None = None + # Non-last-PP-rank kv_connector_output is computed in execute_model and + # consumed by sample_tokens so the aggregator can count this rank's + # finished_sending / finished_recving ticks. + self.kv_connector_output: Any = None # Expert parallelism load balancer. self.eplb = EPLBController(self.parallel_config, self.device) @@ -1218,6 +1227,10 @@ def execute_model( if not self.is_last_pp_rank: # Non-last PP rank: return IntermediateTensors for sending. + assert output_intermediate_tensors is not None + kv_connector_output = self.kv_connector.post_forward(finished_req_ids) + output_intermediate_tensors.kv_connector_output = kv_connector_output + self.kv_connector_output = kv_connector_output return output_intermediate_tensors return None @@ -1246,10 +1259,18 @@ def sample_tokens( input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1 ) self.postprocess(input_batch, sampled, num_sampled, num_rejected) - - # Post-step KV connector related operations. 
- kv_connector_output = self.kv_connector.post_forward(finished_req_ids) - return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output) + # Return an empty ModelRunnerOutput carrying the local + # kv_connector_output so KVOutputAggregator.aggregate can + # count this rank's finished_sending / finished_recving ticks. + # Returning None drops the rank's per-request completion signal + # and trips the aggregator's non-None invariant. + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + if kv_connector_output is None or kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output # Last rank: sample tokens sampler_output, num_sampled, num_rejected = self.sample( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index e63f50bc8dc2..5832fe0c9ae8 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -34,6 +34,9 @@ get_kv_transfer_group, has_kv_transfer_group, ) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorHandshakeMetadata, +) from vllm.distributed.parallel_state import ( Handle, get_pp_group, @@ -512,8 +515,15 @@ def determine_available_memory(self) -> int: return int(self.available_kv_cache_memory_bytes) - def get_kv_connector_handshake_metadata(self) -> dict | None: - """Get KV connector metadata from this worker if available.""" + def get_kv_connector_handshake_metadata( + self, + ) -> dict[tuple[int, int], KVConnectorHandshakeMetadata] | None: + """Get KV connector metadata from this worker if available. + + Returned dict is keyed by `(pp_rank, tp_rank)`. Engine core unwraps + the tuple keys for connectors that do not declare PP-aware support + (i.e. do not inherit from ``SupportsPP``; see ``supports_pp``). 
+ """ if not has_kv_transfer_group(): return None @@ -524,8 +534,9 @@ def get_kv_connector_handshake_metadata(self) -> dict | None: if (metadata := connector.get_handshake_metadata()) is None: return None + pp_rank = get_pp_group().rank_in_group tp_rank = get_tp_group().rank_in_group - return {tp_rank: metadata} + return {(pp_rank, tp_rank): metadata} def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() From c1336b2a2998863a2eb56ecfbf36e58b136cdb28 Mon Sep 17 00:00:00 2001 From: zixi-qi Date: Fri, 29 May 2026 06:16:42 +0000 Subject: [PATCH 2/5] [KVConnector] Rebase onto main: drop intermediate-PP output plumbing superseded by #43732 Co-authored-by: Claude Signed-off-by: zixi-qi --- .../unit/test_pp_intermediate_output.py | 137 ------------------ vllm/v1/worker/gpu/model_runner.py | 31 +--- 2 files changed, 5 insertions(+), 163 deletions(-) delete mode 100644 tests/v1/kv_connector/unit/test_pp_intermediate_output.py diff --git a/tests/v1/kv_connector/unit/test_pp_intermediate_output.py b/tests/v1/kv_connector/unit/test_pp_intermediate_output.py deleted file mode 100644 index 5d858e151746..000000000000 --- a/tests/v1/kv_connector/unit/test_pp_intermediate_output.py +++ /dev/null @@ -1,137 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from types import SimpleNamespace -from typing import Any, cast - -import pytest - -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput -from vllm.v1.worker.gpu import model_runner as model_runner_module -from vllm.v1.worker.gpu.model_runner import ExecuteModelState, GPUModelRunner - -pytestmark = pytest.mark.cpu_test - - -def _make_non_last_runner( - kv_connector_output: KVConnectorOutput | None, -) -> GPUModelRunner: - runner = GPUModelRunner.__new__(GPUModelRunner) - runner.execute_model_state = ExecuteModelState( - input_batch=cast(Any, SimpleNamespace(num_reqs=1)), - attn_metadata=None, - 
slot_mappings_by_layer=None, - hidden_states=None, - aux_hidden_states=None, - finished_req_ids=set(), - ) - runner.kv_connector_output = kv_connector_output - runner.is_last_pp_rank = False - runner.num_speculative_steps = 0 - runner.eplb = SimpleNamespace(step=lambda **_: None) - return runner - - -def _patch_non_last_pp_side_effects(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr( - model_runner_module, - "pp_receive", - lambda num_reqs, max_sample_len: ([], 0, 0), - ) - monkeypatch.setattr( - GPUModelRunner, - "postprocess", - lambda self, input_batch, sampled, num_sampled, num_rejected: None, - ) - - -def test_non_last_pp_without_kv_output_returns_empty_model_runner_output( - monkeypatch: pytest.MonkeyPatch, -) -> None: - _patch_non_last_pp_side_effects(monkeypatch) - runner = _make_non_last_runner(None) - - output = runner.sample_tokens(None) - - assert output is EMPTY_MODEL_RUNNER_OUTPUT - - -def test_non_last_pp_with_kv_output_returns_copy_carrying_kv_output( - monkeypatch: pytest.MonkeyPatch, -) -> None: - _patch_non_last_pp_side_effects(monkeypatch) - kv_connector_output = KVConnectorOutput(finished_sending={"req-0"}) - runner = _make_non_last_runner(kv_connector_output) - - output = runner.sample_tokens(None) - - assert output is not EMPTY_MODEL_RUNNER_OUTPUT - assert output is not None - assert output.kv_connector_output is kv_connector_output - assert EMPTY_MODEL_RUNNER_OUTPUT.kv_connector_output is None - - -def test_last_pp_rank_still_returns_regular_model_runner_output( - monkeypatch: pytest.MonkeyPatch, -) -> None: - class _FakeAsyncOutput: - def __init__(self, model_runner_output, *args, **kwargs) -> None: - self.model_runner_output = model_runner_output - - def get_output(self): - return self.model_runner_output - - monkeypatch.setattr(model_runner_module, "AsyncOutput", _FakeAsyncOutput) - monkeypatch.setattr( - GPUModelRunner, - "sample", - lambda self, hidden_states, input_batch, grammar_output: ( - 
SimpleNamespace(sampled_token_ids=[[1]]), - 1, - 0, - ), - ) - monkeypatch.setattr( - GPUModelRunner, - "postprocess", - lambda self, input_batch, sampled, num_sampled, num_rejected: None, - ) - - kv_connector_output = KVConnectorOutput(finished_recving={"req-0"}) - runner = GPUModelRunner.__new__(GPUModelRunner) - runner.execute_model_state = ExecuteModelState( - input_batch=cast(Any, SimpleNamespace(num_reqs=1, req_ids=["req-0"])), - attn_metadata=None, - slot_mappings_by_layer=None, - hidden_states=cast(Any, object()), - aux_hidden_states=None, - finished_req_ids=set(), - ) - runner.kv_connector = SimpleNamespace( - post_forward=lambda finished_req_ids: kv_connector_output, - ) - runner.kv_connector_output = None - runner.is_last_pp_rank = True - runner.use_pp = False - runner.use_async_scheduling = False - runner.speculator = None - runner.main_stream = None - runner.output_copy_stream = None - runner.eplb = SimpleNamespace(step=lambda **_: None) - runner.model = SimpleNamespace(compute_logits=lambda *args, **kwargs: None) - runner.prompt_logprobs_worker = SimpleNamespace( - compute_prompt_logprobs=lambda *args, **kwargs: {} - ) - runner.req_states = SimpleNamespace( - all_token_ids=SimpleNamespace(gpu=None), - num_computed_tokens=SimpleNamespace(gpu=None), - prompt_len=SimpleNamespace(np=None), - prefill_len=SimpleNamespace(np=None), - num_computed_prefill_tokens=None, - ) - - output = runner.sample_tokens(None) - - assert output is not None - assert output.req_ids == ["req-0"] - assert output.kv_connector_output is kv_connector_output diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 0e618ae71142..5cba66e5c9f9 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -17,7 +17,6 @@ instead of embedding feature-specific logic directly. 
""" -import copy import functools import gc import time @@ -51,11 +50,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec -from vllm.v1.outputs import ( - EMPTY_MODEL_RUNNER_OUTPUT, - DraftTokenIds, - ModelRunnerOutput, -) +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput from vllm.v1.worker.gpu.attn_utils import ( @@ -249,10 +244,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # For transferring state from execute_model to subsequent sample_tokens call. self.execute_model_state: ExecuteModelState | None = None - # Non-last-PP-rank kv_connector_output is computed in execute_model and - # consumed by sample_tokens so the aggregator can count this rank's - # finished_sending / finished_recving ticks. - self.kv_connector_output: Any = None # Expert parallelism load balancer. self.eplb = EPLBController(self.parallel_config, self.device) @@ -1227,10 +1218,6 @@ def execute_model( if not self.is_last_pp_rank: # Non-last PP rank: return IntermediateTensors for sending. - assert output_intermediate_tensors is not None - kv_connector_output = self.kv_connector.post_forward(finished_req_ids) - output_intermediate_tensors.kv_connector_output = kv_connector_output - self.kv_connector_output = kv_connector_output return output_intermediate_tensors return None @@ -1259,18 +1246,10 @@ def sample_tokens( input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1 ) self.postprocess(input_batch, sampled, num_sampled, num_rejected) - # Return an empty ModelRunnerOutput carrying the local - # kv_connector_output so KVOutputAggregator.aggregate can - # count this rank's finished_sending / finished_recving ticks. 
- # Returning None drops the rank's per-request completion signal - # and trips the aggregator's non-None invariant. - kv_connector_output = self.kv_connector_output - self.kv_connector_output = None - if kv_connector_output is None or kv_connector_output.is_empty(): - return EMPTY_MODEL_RUNNER_OUTPUT - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.kv_connector_output = kv_connector_output - return output + + # Post-step KV connector related operations. + kv_connector_output = self.kv_connector.post_forward(finished_req_ids) + return ModelRunnerOutput.with_kv_conn_output_only(kv_connector_output) # Last rank: sample tokens sampler_output, num_sampled, num_rejected = self.sample( From 29491ee2df9b7bc6f745d103e7fcebc216548a26 Mon Sep 17 00:00:00 2001 From: zixi-qi Date: Sat, 30 May 2026 22:24:54 +0000 Subject: [PATCH 3/5] [KVConnector] Remove SupportsPP marker; default pp-aware handshake in base Co-authored-by: Claude Signed-off-by: zixi-qi --- .../unit/test_handshake_pp_aggregation.py | 3 +- .../kv_transfer/kv_connector/v1/base.py | 53 +++++++------------ .../kv_connector/v1/multi_connector.py | 19 ++----- vllm/v1/engine/core.py | 23 ++------ 4 files changed, 31 insertions(+), 67 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py b/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py index 68f96b35038e..b0f4e94d7068 100644 --- a/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py +++ b/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py @@ -9,7 +9,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorHandshakeMetadata, - SupportsPP, ) from vllm.v1.engine import core as engine_core_module @@ -158,7 +157,7 @@ def set_xfer_handshake_metadata( self.legacy_metadata = metadata -class _PPAwareConnector(_LegacyConnector, SupportsPP): +class _PPAwareConnector(_LegacyConnector): def __init__(self) -> None: super().__init__() self.pp_aware_metadata: ( diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index cfb633171b95..9eaa4aaa4b00 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -121,39 +121,6 @@ def supports_hma(connector: Any) -> bool: return isinstance(connector, SupportsHMA) -class SupportsPP(ABC): - """ - The class that indicates the corresponding connector supports - pipeline-parallel (PP) disaggregated serving. - - Connectors that inherit from this class receive KV connector handshake - metadata keyed by ``(pp_rank, tp_rank)`` instead of just ``tp_rank``, so - they can track per-PP-rank remote state. - """ - - @abstractmethod - def set_xfer_handshake_metadata_pp_aware( - self, metadata: dict[tuple[int, int], "KVConnectorHandshakeMetadata"] - ) -> None: - """ - Set PP-aware KV connector handshake metadata for this connector. - - NOTE: This function is only supported by connectors that support PP - disaggregation (inherit from ``SupportsPP``). Connectors that do not - inherit from ``SupportsPP`` keep receiving ``{tp_rank: metadata}`` - via ``set_xfer_handshake_metadata``; engine core unwraps the - tuple-keyed dict at the dispatch site. - """ - raise NotImplementedError - - -def supports_pp(connector: Any) -> bool: - if isinstance(connector, type): - return issubclass(connector, SupportsPP) - else: - return isinstance(connector, SupportsPP) - - class KVConnectorRole(enum.Enum): # Connector running in the scheduler process SCHEDULER = 0 @@ -677,6 +644,26 @@ def set_xfer_handshake_metadata( """ return None + def set_xfer_handshake_metadata_pp_aware( + self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata] + ) -> None: + """ + Set handshake metadata keyed by ``(pp_rank, tp_rank)``. 
+ + Default implementation supports only single-PP-rank producers: it keeps + the ``pp_rank == 0`` entries, drops the rest, and forwards the resulting + ``{tp_rank: metadata}`` to ``set_xfer_handshake_metadata``. Connectors + that support PP-disaggregated transfer override this to consume metadata + from all PP producer shards. + """ + self.set_xfer_handshake_metadata( + { + tp_rank: meta + for (pp_rank, tp_rank), meta in metadata.items() + if pp_rank == 0 + } + ) + @classmethod def build_prom_metrics( cls, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index a0316e081e06..f43aa36bb725 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -19,8 +19,6 @@ KVConnectorRole, KVConnectorWorkerMetadata, SupportsHMA, - SupportsPP, - supports_pp, ) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, @@ -127,7 +125,7 @@ def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx) -class MultiConnector(KVConnectorBase_V1, SupportsHMA, SupportsPP): +class MultiConnector(KVConnectorBase_V1, SupportsHMA): """ A wrapper for using multiple KVConnectors at the same time. @@ -478,19 +476,12 @@ def set_xfer_handshake_metadata_pp_aware( ) -> None: """ Set PP-aware KV connector handshake metadata for all sub-connectors. - Non-``SupportsPP`` children receive the ``{tp_rank: metadata}`` shape - via a ``(0, tp_rank)`` unwrap. + Each child consumes the ``(pp_rank, tp_rank)``-keyed dict via its own + ``set_xfer_handshake_metadata_pp_aware``; children that do not support + PP fall back to the base default (``pp_rank == 0`` unwrap). 
""" for c in self._connectors: - if supports_pp(c): - cast(SupportsPP, c).set_xfer_handshake_metadata_pp_aware(metadata) - else: - non_pp_metadata = { - tp_rank: meta - for (pp_rank, tp_rank), meta in metadata.items() - if pp_rank == 0 - } - c.set_xfer_handshake_metadata(non_pp_metadata) + c.set_xfer_handshake_metadata_pp_aware(metadata) def _aggregate_request_finished( self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 05697a3991cc..4d3193fef12b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -26,10 +26,6 @@ cleanup_dist_env_and_memory, stateless_destroy_torch_distributed_process_group, ) -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - SupportsPP, - supports_pp, -) from vllm.envs import enable_envs_cache from vllm.logger import init_logger from vllm.logging_utils.dump_input import dump_engine_exception @@ -187,20 +183,11 @@ def __init__( if worker_dict is not None: content.update(worker_dict) - if supports_pp(kv_connector): - cast(SupportsPP, kv_connector).set_xfer_handshake_metadata_pp_aware( - content - ) - else: - # Unwrap (0, tp_rank) entries to {tp_rank: metadata} for - # non-PP connectors. Entries with pp_rank != 0 cannot be - # consumed by a non-PP connector and are dropped. - non_pp_content: dict[int, Any] = { - tp_rank: meta - for (pp_rank, tp_rank), meta in content.items() - if pp_rank == 0 - } - kv_connector.set_xfer_handshake_metadata(non_pp_content) + # Connectors receive metadata keyed by (pp_rank, tp_rank). The + # base method default keeps only pp_rank==0 and forwards + # {tp_rank: metadata} to set_xfer_handshake_metadata; PP-aware + # connectors override it to consume all PP producer shards. + kv_connector.set_xfer_handshake_metadata_pp_aware(content) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. 
This enables us to asynchronously From b2da7722d0a02660ccfcd96787c5dbea93424fcc Mon Sep 17 00:00:00 2001 From: zixi-qi Date: Wed, 27 May 2026 01:37:45 +0000 Subject: [PATCH 4/5] [KVConnector][NIXL] Enable PP-disaggregated KV transfer (single-group, no HMA) Signed-off-by: zixi-qi --- .../nixl_side_channel_probe.py | 5 +- .../test_consumer_shard_refactor.py | 235 +++++ .../test_handshake_aggregation.py | 245 +++++ .../nixl_integration/test_pp_layer_map.py | 172 ++++ .../unit/test_bidirectional_kv_transfer.py | 1 + .../kv_connector/unit/test_nixl_connector.py | 176 +++- .../unit/test_nixl_connector_hma.py | 31 +- tests/v1/kv_connector/unit/utils.py | 4 +- .../kv_connector/v1/nixl/__init__.py | 4 + .../kv_connector/v1/nixl/connector.py | 10 +- .../kv_connector/v1/nixl/metadata.py | 16 +- .../kv_connector/v1/nixl/pp_layer_map.py | 71 ++ .../kv_connector/v1/nixl/scheduler.py | 30 +- .../kv_connector/v1/nixl/worker.py | 864 ++++++++++++------ 14 files changed, 1535 insertions(+), 329 deletions(-) create mode 100644 tests/v1/kv_connector/nixl_integration/test_consumer_shard_refactor.py create mode 100644 tests/v1/kv_connector/nixl_integration/test_handshake_aggregation.py create mode 100644 tests/v1/kv_connector/nixl_integration/test_pp_layer_map.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/nixl/pp_layer_map.py diff --git a/tests/v1/kv_connector/nixl_integration/nixl_side_channel_probe.py b/tests/v1/kv_connector/nixl_integration/nixl_side_channel_probe.py index 24ecbd795e41..4d1e84103411 100644 --- a/tests/v1/kv_connector/nixl_integration/nixl_side_channel_probe.py +++ b/tests/v1/kv_connector/nixl_integration/nixl_side_channel_probe.py @@ -15,7 +15,8 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--host", required=True) parser.add_argument("--port", required=True, type=int) - parser.add_argument("--rank", default=0, type=int) + parser.add_argument("--pp-rank", default=0, type=int) + 
parser.add_argument("--tp-rank", default=0, type=int) parser.add_argument("--timeout-ms", default=1000, type=int) return parser.parse_args() @@ -37,7 +38,7 @@ def main() -> None: sock.setsockopt(zmq.RCVTIMEO, args.timeout_ms) try: sock.connect(make_zmq_path(args.host, args.port)) - sock.send(msgspec.msgpack.encode((GET_META_MSG, args.rank))) + sock.send(msgspec.msgpack.encode((GET_META_MSG, args.pp_rank, args.tp_rank))) sock.recv() finally: sock.close() diff --git a/tests/v1/kv_connector/nixl_integration/test_consumer_shard_refactor.py b/tests/v1/kv_connector/nixl_integration/test_consumer_shard_refactor.py new file mode 100644 index 000000000000..50fc310c74f5 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_consumer_shard_refactor.py @@ -0,0 +1,235 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import defaultdict +from types import SimpleNamespace + +from tests.v1.kv_connector.nixl_integration.test_pp_layer_map import ( + _FakeAttentionBackend, + _meta, +) +from vllm.distributed.kv_transfer.kv_connector.utils import TransferTopology +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( + NixlAgentMetadata, + RemoteMeta, + ReqMeta, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec + +REMOTE_ENGINE_ID = "engine" +LOCAL_ENGINE_ID = "local-engine" + + +class _FakeNixlWrapper: + def __init__(self) -> None: + self._next_handle = 0 + + def add_remote_agent(self, agent_metadata: bytes) -> str: + return f"remote-agent-{len(agent_metadata)}-{self._next_handle}" + + def get_xfer_descs(self, blocks_data, memory_type: str) -> list[object]: + return list(blocks_data) + + def prep_xfer_dlist(self, agent_name: str, descs: list[object]) -> int: + self._next_handle += 1 + return self._next_handle + + def make_prepped_xfer(self, *args, **kwargs) -> int: + 
self._next_handle += 1 + return self._next_handle + + def transfer(self, handle: int) -> None: + pass + + def send_notif(self, agent_name: str, notif_msg: bytes) -> None: + pass + + +def _make_worker(total_layers: int = 32) -> NixlConnectorWorker: + worker = NixlConnectorWorker.__new__(NixlConnectorWorker) + worker.engine_id = LOCAL_ENGINE_ID + worker.tp_rank = 0 + worker.world_size = 1 + worker.block_size = 16 + worker.num_blocks = 4 + worker._logical_num_blocks = worker.num_blocks + worker._physical_blocks_per_logical_kv_block = 1 + worker.use_mla = False + worker._has_mamba = False + worker._is_hma_required = False + worker._group_spec_types = (FullAttentionSpec,) + worker.kv_cache_config = SimpleNamespace(kv_cache_groups=[object()]) + worker.transfer_topo = TransferTopology( + tp_rank=worker.tp_rank, + tp_size=worker.world_size, + block_size=worker.block_size, + engine_id=worker.engine_id, + is_mla=worker.use_mla, + is_mamba=False, + total_num_kv_heads=8, + attn_backends=[_FakeAttentionBackend], + tensor_shape=_FakeAttentionBackend.get_kv_cache_shape( + num_blocks=1, + block_size=worker.block_size, + num_kv_heads=1, + head_size=1, + ), + ) + worker.model_config = SimpleNamespace( + get_total_num_hidden_layers=lambda: total_layers + ) + worker.kv_cache_layout = "HND" + worker.host_buffer_kv_cache_layout = "HND" + worker.use_host_buffer = False + worker.kv_transfer_config = SimpleNamespace(enable_permute_local_kv=False) + worker.backend_name = "FLASH_ATTN" + worker.block_len_per_layer = [1024] * total_layers + worker.local_registered_layer_indices = list(range(total_layers)) + worker.local_seen_layer_names = [ + f"model.layers.{layer}.self_attn" for layer in range(total_layers) + ] + worker._layer_name_to_kv_group_index = { + layer_name: 0 for layer_name in worker.local_seen_layer_names + } + worker.device_id = 0 + worker.nixl_memory_type = "DRAM" + worker.nixl_wrapper = _FakeNixlWrapper() + worker.kv_caches_base_addr = defaultdict(dict) + 
worker._local_kv_cache_key = (0, worker.tp_rank) + worker.kv_caches_base_addr[worker.engine_id][worker._local_kv_cache_key] = [ + 100_000 + layer * 10_000 for layer in range(total_layers) + ] + worker._remote_agents = defaultdict(dict) + worker._remote_agent_metadata = defaultdict(dict) + worker._pp_layer_map = {} + worker.src_xfer_handles_by_remote = {} + worker.src_blocks_data_by_remote = {} + worker.src_xfer_handles_by_shard_tp_ratio = {} + worker.dst_xfer_side_handles = defaultdict(dict) + worker._xfer_desc_layouts = {} + worker.tp_mappings = {} + worker.dst_num_blocks = {} + worker._physical_blocks_per_logical = {} + worker._recving_transfers = defaultdict(list) + worker._recving_metadata = {} + worker._failed_recv_reqs = set() + worker._invalid_block_ids = set() + worker.enable_permute_local_kv = False + worker.enable_heterogeneous_attn_post_process = False + worker.xfer_stats = SimpleNamespace( + record_failed_notification=lambda: None, + record_failed_transfer=lambda: None, + ) + return worker + + +def _remote_meta( + worker: NixlConnectorWorker, + pp_rank: int, + start_layer: int, + end_layer: int, + *, + pp_size: int, +) -> NixlAgentMetadata: + meta = _meta( + pp_rank, + start_layer, + end_layer, + pp_size=pp_size, + registered_layer_indices=list(range(start_layer, end_layer)), + ) + meta.num_blocks = worker.num_blocks + meta.attn_backend_name = worker.backend_name + meta.kv_caches_base_addr = [ + 200_000 + layer * 10_000 for layer in range(start_layer, end_layer) + ] + return meta + + +def _add_two_remote_shards(worker: NixlConnectorWorker) -> list[NixlAgentMetadata]: + metas = [ + _remote_meta(worker, 0, 0, 16, pp_size=2), + _remote_meta(worker, 1, 16, 32, pp_size=2), + ] + for meta in metas: + worker.add_remote_agent( + meta, + remote_tp_rank=0, + remote_tp_size=1, + remote_pp_rank=meta.pp_rank, + remote_pp_size=meta.pp_size, + ) + return metas + + +def test_add_remote_agent_records_both_pp_shard_base_address_keys(): + worker = _make_worker() + + 
_add_two_remote_shards(worker) + + assert set(worker.kv_caches_base_addr[REMOTE_ENGINE_ID]) == {(0, 0), (1, 0)} + + +def test_validate_remote_agent_handshake_accepts_synthetic_pp_shard(): + worker = _make_worker() + meta = _remote_meta(worker, 0, 0, 16, pp_size=2) + + worker.add_remote_agent( + meta, + remote_tp_rank=0, + remote_tp_size=1, + remote_pp_rank=0, + remote_pp_size=2, + ) + worker._validate_remote_agent_handshake(meta, 0, 2, 1) + + +def test_add_remote_agent_prepares_dst_handles_for_each_pp_shard(): + worker = _make_worker() + + _add_two_remote_shards(worker) + + assert set(worker.dst_xfer_side_handles[REMOTE_ENGINE_ID]) == { + (0, 0), + (1, 0), + } + + +def test_read_blocks_for_req_appends_one_transfer_per_pp_shard_and_tp_target(): + worker = _make_worker() + _add_two_remote_shards(worker) + req_meta = ReqMeta( + local_block_ids=([0, 1],), + local_physical_block_ids=([0, 1],), + tp_size=1, + pp_size=2, + remote=RemoteMeta( + block_ids=([0, 1],), + host="localhost", + port=1234, + engine_id=REMOTE_ENGINE_ID, + request_id="prefill-req", + ), + ) + + worker._read_blocks_for_req("decode-req", req_meta) + + assert len(worker._recving_transfers["decode-req"]) == 2 + + +def test_pp_rank_one_descriptor_ids_are_shard_local(): + worker = _make_worker() + _add_two_remote_shards(worker) + + remote_desc_ids = worker._get_block_descs_ids_for_shard( + REMOTE_ENGINE_ID, 1, "remote", ([0],) + ) + local_desc_ids = worker._get_block_descs_ids_for_shard( + REMOTE_ENGINE_ID, 1, "local", ([0],) + ) + + assert remote_desc_ids[0] == 0 + assert local_desc_ids[0] == 0 diff --git a/tests/v1/kv_connector/nixl_integration/test_handshake_aggregation.py b/tests/v1/kv_connector/nixl_integration/test_handshake_aggregation.py new file mode 100644 index 000000000000..1e4380a66a37 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_handshake_aggregation.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project + +import threading +from collections.abc import Callable +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, patch + +import msgspec +import pytest +import zmq + +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( + GET_META_MSG, + NixlHandshakePayload, + RemoteMeta, + ReqMeta, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import ( + NixlConnectorScheduler, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, + ReadSpec, +) + + +class _InlineThread: + def __init__( + self, + *, + target: Callable[..., Any], + args: tuple[Any, ...], + **_: Any, + ) -> None: + self._target = target + self._args = args + + def start(self) -> None: + self._target(*self._args) + + +class _FakeZmqContext: + def __init__(self, sock: "_FakeHandshakeSocket") -> None: + self._sock = sock + + def __enter__(self) -> "_FakeHandshakeSocket": + return self._sock + + def __exit__(self, *args: Any) -> None: + return None + + +class _FakeHandshakeSocket: + def __init__( + self, + request_msg: bytes, + *, + stop_event: threading.Event | None = None, + ) -> None: + self._request_msg = request_msg + self._stop_event = stop_event + self._recv_count = 0 + self.sent_multipart: list[tuple[bytes, bytes, bytes]] = [] + + def setsockopt(self, *_: Any) -> None: + return None + + def recv_multipart(self) -> tuple[bytes, bytes, bytes]: + if self._recv_count == 0: + self._recv_count += 1 + return (b"identity", b"", self._request_msg) + if self._stop_event is not None: + self._stop_event.set() + raise zmq.Again() + + def send_multipart(self, parts: tuple[bytes, bytes, bytes]) -> None: + self.sent_multipart.append(parts) + + +def test_engine_merge_preserves_pp_and_tp_keys(): + metadata_a = object() + metadata_b = object() + metadata_c = object() + worker_dicts = [ + {(0, 0): metadata_a}, + {(1, 0): metadata_b}, + {(0, 1): metadata_c}, + ] + + content: dict[tuple[int, int], 
object] = {} + for worker_dict in worker_dicts: + content.update(worker_dict) + + assert content == { + (0, 0): metadata_a, + (1, 0): metadata_b, + (0, 1): metadata_c, + } + + +def test_scheduler_listener_serves_three_tuple_key(): + scheduler = NixlConnectorScheduler.__new__(NixlConnectorScheduler) + scheduler._nixl_handshake_listener_t = None + scheduler._stop_event = threading.Event() + scheduler.side_channel_host = "localhost" + scheduler.side_channel_port = 1234 + + payload = NixlHandshakePayload( + compatibility_hash="hash", + agent_metadata_bytes=b"agent", + ) + request = msgspec.msgpack.encode((GET_META_MSG, 1, 0)) + sock = _FakeHandshakeSocket(request, stop_event=scheduler._stop_event) + + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler.zmq_ctx", + return_value=_FakeZmqContext(sock), + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler." + "threading.Thread", + _InlineThread, + ), + ): + scheduler.set_xfer_handshake_metadata({(1, 0): payload}) + + assert len(sock.sent_multipart) == 1 + identity, delimiter, encoded_payload = sock.sent_multipart[0] + assert identity == b"identity" + assert delimiter == b"" + decoded_payload = msgspec.msgpack.decode(encoded_payload, type=NixlHandshakePayload) + assert decoded_payload == payload + + +def test_ensure_handshake_treats_partial_pp_state_as_inflight(): + worker = NixlConnectorWorker.__new__(NixlConnectorWorker) + future = MagicMock() + remote_engine_id = "remote-engine" + worker._handshake_lock = threading.RLock() + worker._handshake_futures = {remote_engine_id: future} + worker._remote_agents = {remote_engine_id: {(0, 0): "agent-0-0"}} + worker._pp_layer_map = {} + + assert worker._ensure_handshake(remote_engine_id, "localhost", 1234, 1, 2) is future + + +def test_handshake_complete_requires_pp_layer_map(): + worker = NixlConnectorWorker.__new__(NixlConnectorWorker) + remote_engine_id = "remote-engine" + worker._handshake_futures = {} + worker._remote_agents = 
{remote_engine_id: {(0, 0): "agent-0-0"}} + worker._pp_layer_map = {} + + assert not worker._handshake_complete(remote_engine_id, 2) + + worker._pp_layer_map[remote_engine_id] = SimpleNamespace(pp_size=2) + + assert worker._handshake_complete(remote_engine_id, 2) + + +@pytest.mark.parametrize("pp_size", [1, 4]) +def test_background_nixl_handshake_submits_remote_pp_size(pp_size: int): + worker = NixlConnectorWorker.__new__(NixlConnectorWorker) + worker._handshake_futures = {} + worker._handshake_initiation_executor = MagicMock() + future = MagicMock() + worker._handshake_initiation_executor.submit.return_value = future + worker._handshake_lock = threading.Lock() + worker._remote_agents = {} + worker._ready_requests = MagicMock() + worker._log_failure = MagicMock() + worker._recving_transfers = {} + worker.src_xfer_handles_by_remote = {} + worker.src_xfer_handles_by_shard_tp_ratio = {} + worker.dst_xfer_side_handles = {} + worker._registered_descs = [] + + remote_engine_id = "remote-engine" + meta = ReqMeta( + local_block_ids=([0],), + local_physical_block_ids=([0],), + tp_size=2, + pp_size=pp_size, + remote=RemoteMeta( + block_ids=([1],), + host="localhost", + port=1234, + engine_id=remote_engine_id, + request_id="remote-request", + ), + ) + + worker._background_nixl_handshake("request", remote_engine_id, meta) + + worker._handshake_initiation_executor.submit.assert_called_once_with( + worker._nixl_handshake, + "localhost", + 1234, + 2, + pp_size, + remote_engine_id, + ) + assert future.add_done_callback.call_count == 2 + + +def test_hma_pp_assertion_guard_in_read_blocks() -> None: + """NIXL PR1 must reject HMA × PP combinations with AssertionError. + + This guard is the PR1↔PR2 split point. When NIXL PR2 lands per-layer-name + HMA × PP routing, the ``assert not self._is_hma_required`` checks inside + ``_read_blocks`` and friends will be lifted. 
Until then, configuring NIXL + with HMA enabled and a heterogeneous block-size remote (the path that + co-occurs under ``pp_size > 1`` with multi-group KV caches) must fail loud. + """ + import numpy as np + + worker = NixlConnectorWorker.__new__(NixlConnectorWorker) + worker._is_hma_required = True + worker.world_size = 1 + worker.block_size = 16 + worker._remote_agents = {"remote-engine": {(0, 0): "agent-0-0"}} + + transfer_topo = MagicMock() + transfer_topo.get_engine_info.return_value = SimpleNamespace( + remote_block_size=8, + remote_physical_blocks_per_logical=1, + ) + transfer_topo.block_size_ratio.return_value = 2 + worker.transfer_topo = transfer_topo + worker.get_mapped_blocks = MagicMock(return_value=np.asarray([0, 1, 2, 3])) + + spec = ReadSpec(remote_rank=0, local_block_ids=[[0]], remote_block_ids=[[1]]) + with pytest.raises(AssertionError): + worker._read_blocks( + read_spec=spec, + request_id="req", + dst_engine_id="remote-engine", + remote_request_id="rreq", + remote_pp_rank=0, + local_xfer_side_handle=0, + remote_xfer_side_handle=0, + ) diff --git a/tests/v1/kv_connector/nixl_integration/test_pp_layer_map.py b/tests/v1/kv_connector/nixl_integration/test_pp_layer_map.py new file mode 100644 index 000000000000..da874f3e7c41 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_pp_layer_map.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.v1.nixl import ( + NixlAgentMetadata, + NixlConnectorMetadata, + PPLayerMap, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( + NIXL_CONNECTOR_VERSION, + compute_nixl_compatibility_hash, +) + + +def _meta( + pp_rank: int, + start_layer: int, + end_layer: int, + *, + pp_size: int = 4, + registered_layer_indices: list[int] | None = None, + registered_layer_names: list[str] | None = None, +) -> 
NixlAgentMetadata: + if registered_layer_indices is None: + registered_layer_indices = list(range(start_layer, end_layer)) + if registered_layer_names is None: + registered_layer_names = [ + f"model.layers.{idx}.self_attn" for idx in registered_layer_indices + ] + return NixlAgentMetadata( + engine_id="engine", + agent_metadata=b"agent", + kv_caches_base_addr=list(range(len(registered_layer_indices))), + device_id=pp_rank, + num_blocks=1, + block_lens=[1024] * len(registered_layer_indices), + kv_cache_layout="HND", + block_size=16, + ssm_sizes=(0, 0), + attn_backend_name="FLASH_ATTN", + physical_blocks_per_logical_kv_block=1, + pp_rank=pp_rank, + pp_size=pp_size, + start_layer=start_layer, + end_layer=end_layer, + registered_layer_indices=registered_layer_indices, + registered_layer_names=registered_layer_names, + ) + + +def _metas(boundaries: list[tuple[int, int]]) -> list[NixlAgentMetadata]: + pp_size = len(boundaries) + return [ + _meta(rank, start, end, pp_size=pp_size) + for rank, (start, end) in enumerate(boundaries) + ] + + +class _FakeAttentionBackend: + @staticmethod + def get_name() -> str: + return "FAKE_ATTN" + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, int, int, int, int]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + +def _fake_vllm_config(pipeline_parallel_size: int = 1) -> SimpleNamespace: + model_config = SimpleNamespace( + model="fake-model", + dtype="float16", + get_total_num_kv_heads=lambda: 8, + get_head_size=lambda: 16, + get_total_num_hidden_layers=lambda: 32, + ) + return SimpleNamespace( + model_config=model_config, + cache_config=SimpleNamespace(cache_dtype="auto", block_size=16), + scheduler_config=SimpleNamespace(disable_hybrid_kv_cache_manager=False), + parallel_config=SimpleNamespace( + pipeline_parallel_size=pipeline_parallel_size, + tensor_parallel_size=1, + ), + ) + + +def test_nixl_connector_version_is_bumped_to_v5(): + 
assert NIXL_CONNECTOR_VERSION == 5 + + +def test_pp_layer_map_round_trip_queries(): + layer_map = PPLayerMap.from_metadata_shards( + _metas([(0, 8), (8, 16), (16, 24), (24, 32)]), + total_num_hidden_layers=32, + ) + + assert layer_map.pp_size == 4 + assert layer_map.boundaries == ((0, 8), (8, 16), (16, 24), (24, 32)) + assert layer_map.registered_layer_indices == tuple( + tuple(range(start, end)) for start, end in ((0, 8), (8, 16), (16, 24), (24, 32)) + ) + assert layer_map.total_num_hidden_layers == 32 + + +@pytest.mark.parametrize( + "boundaries,total", + [ + ([(0, 8), (10, 16), (16, 24), (24, 32)], 32), + ([(0, 8), (8, 15), (16, 24), (24, 32)], 32), + ([(0, 8), (8, 16), (16, 24), (24, 30)], 32), + ], +) +def test_pp_layer_map_rejects_non_contiguous_or_incomplete_coverage(boundaries, total): + with pytest.raises(ValueError, match="boundaries must"): + PPLayerMap.from_metadata_shards( + _metas(boundaries), total_num_hidden_layers=total + ) + + +def test_pp_layer_map_collapses_duplicate_tp_records(): + metas = _metas([(0, 8), (8, 16), (16, 24), (24, 32)]) + metas.append(_meta(1, 8, 16, pp_size=4)) + + layer_map = PPLayerMap.from_metadata_shards(metas, total_num_hidden_layers=32) + + assert layer_map.registered_layer_indices[1] == tuple(range(8, 16)) + + +def test_pp_layer_map_rejects_conflicting_duplicate_tp_records(): + metas = _metas([(0, 8), (8, 16), (16, 24), (24, 32)]) + metas.append(_meta(1, 8, 16, pp_size=4, registered_layer_indices=[8, 8])) + + with pytest.raises(ValueError, match="conflicting metadata"): + PPLayerMap.from_metadata_shards(metas, total_num_hidden_layers=32) + + +def test_compatibility_hash_ignores_pipeline_parallel_size(): + assert compute_nixl_compatibility_hash( + _fake_vllm_config(pipeline_parallel_size=1), "FLASH_ATTN", False + ) == compute_nixl_compatibility_hash( + _fake_vllm_config(pipeline_parallel_size=4), "FLASH_ATTN", False + ) + + +def test_req_meta_reads_pp_size_and_defaults_to_one(): + metadata = NixlConnectorMetadata() + 
params = { + "remote_block_ids": ([0],), + "remote_engine_id": "engine", + "remote_request_id": "remote-req", + "remote_host": "localhost", + "remote_port": 1234, + "tp_size": 2, + "pp_size": 4, + } + + metadata.add_new_req_to_recv("req", ([0],), params) + assert metadata.reqs_to_recv["req"].pp_size == 4 + + params.pop("pp_size") + metadata.add_new_req_to_recv("req-default", ([0],), params) + assert metadata.reqs_to_recv["req-default"].pp_size == 1 diff --git a/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py b/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py index dc76d61178d8..1eb928999973 100644 --- a/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py +++ b/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py @@ -105,6 +105,7 @@ def _make_connector_with_fake_worker( host="localhost", port=1234, remote_tp_size=1, + remote_pp_size=1, expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, ) worker._remote_agents[FakeNixlConnectorWorker.REMOTE_ENGINE_ID] = remote_agents diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index f07a8352e735..e3a62e2e6ea9 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -9,6 +9,7 @@ import time import uuid from collections import defaultdict +from types import SimpleNamespace from typing import Any, cast from unittest.mock import MagicMock, patch @@ -20,6 +21,7 @@ from vllm import LLM from vllm.config import KVTransferConfig, set_current_vllm_config from vllm.distributed.kv_transfer.kv_connector.utils import ( + EngineTransferInfo, KVOutputAggregator, TransferTopology, get_current_attn_backend, @@ -373,7 +375,11 @@ def test_kv_transfer_handshake(dist_init): # metadata. 
kv_cache_groups = [ KVCacheGroupSpec( - ["layer0", "layer1", "layer2"], + [ + "model.layers.0.self_attn", + "model.layers.1.self_attn", + "model.layers.2.self_attn", + ], FullAttentionSpec( block_size=16, num_kv_heads=4, @@ -400,9 +406,9 @@ def test_kv_transfer_handshake(dist_init): shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, + "model.layers.0.self_attn": shared_tensor, + "model.layers.1.self_attn": unique_tensor, + "model.layers.2.self_attn": shared_tensor, } prefill_connector.register_kv_caches(kv_caches) @@ -414,11 +420,12 @@ def test_kv_transfer_handshake(dist_init): decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes) - # The scheduler connector expects metadata to be in - # dict[int, KVConnectorHandshakeMetadata], where the first key is - # the dp_rank, the second key is the tp_rank. + # The scheduler connector expects metadata keyed by (pp_rank, tp_rank). + # Engine core dispatches via the PP-aware entry point, which is what + # starts the NIXL handshake listener; the legacy non-PP-aware setter + # does not, so use the PP-aware one here to match production. scheduler_connector = scheduler.get_kv_connector() - scheduler_connector.set_xfer_handshake_metadata({0: metadata}) + scheduler_connector.set_xfer_handshake_metadata_pp_aware({(0, 0): metadata}) # Simulate a request that finishes prefill, which returns # corresponding NixlConnectorMetadata for decode instance. 
@@ -458,6 +465,7 @@ def test_kv_transfer_handshake(dist_init): kv_connector_metadata["remote_host"], kv_connector_metadata["remote_port"], kv_connector_metadata["tp_size"], + kv_connector_metadata.get("pp_size", 1), kv_connector_metadata["remote_engine_id"], ) @@ -486,8 +494,6 @@ def __init__( super().__init__(*args, kv_cache_config=kv_cache_config, **kwargs) self._hand_shake_latency = hand_shake_latency self.kv_cache_layout = kv_cache_layout - # Mock register_kv_caches attribute needed for tests that do not call it. - self.src_xfer_handles_by_block_size = {self.block_size: 1} test_shape = self.attn_backends[0].get_kv_cache_shape( num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 ) @@ -508,10 +514,16 @@ def __init__( ) def _nixl_handshake( - self, host: str, port: int, remote_tp_size: int, expected_engine_id: str - ) -> dict[int, str]: + self, + host: str, + port: int, + remote_tp_size: int, + remote_pp_size: int, + expected_engine_id: str, + ) -> dict[tuple[int, int], str]: # Mimic slow _nixl_handshake, as well as bypass zmq communication. time.sleep(self._hand_shake_latency) + assert remote_pp_size == 1 # These should've been done in register_kv_caches(), called by # gpu_model_runner. Here we just hardcode some dummy values. slot_size_bytes = 4096 @@ -519,6 +531,10 @@ def _nixl_handshake( self.block_len_per_layer = [slot_size_bytes * self.block_size] self.num_blocks = 1 self.dst_num_blocks[self.engine_id] = self.num_blocks + self.local_registered_layer_indices = [0] + self.local_seen_layer_names = ["model.layers.0.self_attn"] + self._local_kv_cache_key = (0, self.tp_rank) + self.kv_caches_base_addr[self.engine_id][self._local_kv_cache_key] = [0] assert expected_engine_id == self.REMOTE_ENGINE_ID @@ -539,7 +555,7 @@ def _nixl_handshake( # When remote tp_size > local tp_size, handshake with multiple # remote ranks. 
num_handshakes = 1 if tp_ratio > 0 else -tp_ratio - remote_agents: dict[int, str] = {} + remote_agents: dict[tuple[int, int], str] = {} for remote_tp_rank in range(num_handshakes): remote_agent_name = self.add_remote_agent( NixlAgentMetadata( @@ -556,11 +572,17 @@ def _nixl_handshake( ssm_sizes=(0, 0), attn_backend_name=self.backend_name, physical_blocks_per_logical_kv_block=1, + pp_rank=0, + pp_size=1, + start_layer=0, + end_layer=1, + registered_layer_indices=[0], + registered_layer_names=["model.layers.0.self_attn"], ), remote_tp_rank=remote_tp_rank, remote_tp_size=remote_tp_size, ) - remote_agents[remote_tp_rank] = remote_agent_name + remote_agents[(0, remote_tp_rank)] = remote_agent_name return remote_agents @@ -601,7 +623,7 @@ def test_multi_xfer_one_engine( worker.nixl_wrapper.set_cycles_before_xfer_done(3) # simulate handshake worker.dst_xfer_side_handles = { - FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1} + FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {(0, 0): 1} } worker.kv_cache_layout = "HND" num_xfers = 4 @@ -754,29 +776,34 @@ def test_prefill_tp_size_greater_than_decode_tp_size( worker.block_len_per_layer = [4096 * worker.block_size] worker.num_blocks = 1 worker.dst_num_blocks[worker.engine_id] = worker.num_blocks - worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)] - worker.num_descs = len(worker.src_blocks_data) + worker.num_descs = 1 + worker.local_registered_layer_indices = [0] + worker.local_seen_layer_names = ["model.layers.0.self_attn"] + worker._local_kv_cache_key = (0, worker.tp_rank) + worker.kv_caches_base_addr[worker.engine_id][worker._local_kv_cache_key] = [0] def check_handshake(remote_tp_size: int): tp_ratio = remote_tp_size // local_tp_size - assert set(remote_agents.keys()) == set(range(tp_ratio)) + assert set(remote_agents.keys()) == {(0, rank) for rank in range(tp_ratio)} remote_engine_id = worker.REMOTE_ENGINE_ID - remote_info = worker.transfer_topo.get_engine_info(remote_engine_id) + remote_info = 
worker.transfer_topo.get_engine_info(remote_engine_id, 0) assert remote_info.remote_tp_size == remote_tp_size assert -tp_ratio == worker.transfer_topo.tp_ratio(remote_tp_size) - # ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks - assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio - assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio + # Each shard registers a list of tp_ratio handles. + split_key = (remote_engine_id, 0, -tp_ratio) + assert split_key in worker.src_xfer_handles_by_shard_tp_ratio + assert len(worker.src_xfer_handles_by_shard_tp_ratio[split_key]) == tp_ratio assert remote_engine_id in worker.dst_xfer_side_handles assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set( - range(tp_ratio) + (0, rank) for rank in range(tp_ratio) ) remote_agents = worker._nixl_handshake( host="localhost", port=1234, remote_tp_size=4, + remote_pp_size=1, expected_engine_id=worker.REMOTE_ENGINE_ID, ) check_handshake(4) @@ -789,6 +816,7 @@ def check_handshake(remote_tp_size: int): host="localhost", port=1234, remote_tp_size=6, + remote_pp_size=1, expected_engine_id=worker.REMOTE_ENGINE_ID, ) check_handshake(6) @@ -917,9 +945,6 @@ def test_concurrent_load_kv( connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id ) - # Register (mocked) local xfer handler - # worker = connector.connector_worker - # worker.src_xfer_handles_by_block_size = {worker.block_size: 1} metadata = NixlConnectorMetadata() total_reqs = 5 for i in range(total_reqs): @@ -994,6 +1019,12 @@ def test_handshake_fails_on_kv_cache_layout_mismatch( worker.block_len_per_layer = [4096 * worker.block_size] worker.num_blocks = 1 worker.dst_num_blocks[worker.engine_id] = worker.num_blocks + worker.local_registered_layer_indices = [0] + worker.local_seen_layer_names = ["model.layers.0.self_attn"] + worker._local_kv_cache_key = (0, worker.tp_rank) + worker.kv_caches_base_addr[worker.engine_id][worker._local_kv_cache_key] = [ + 0 
+ ] # Metadata with different kv_cache_layout than local worker mismatched_layout = "HND" if worker.kv_cache_layout != "HND" else "NHD" @@ -1009,6 +1040,12 @@ def test_handshake_fails_on_kv_cache_layout_mismatch( ssm_sizes=(0, 0), attn_backend_name=worker.backend_name, physical_blocks_per_logical_kv_block=1, + pp_rank=0, + pp_size=1, + start_layer=0, + end_layer=1, + registered_layer_indices=[0], + registered_layer_names=["model.layers.0.self_attn"], ) with pytest.raises(RuntimeError): @@ -1052,6 +1089,12 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( worker.block_len_per_layer = [2048 * worker.block_size] worker.num_blocks = 1 worker.dst_num_blocks[worker.engine_id] = worker.num_blocks + worker.local_registered_layer_indices = [0] + worker.local_seen_layer_names = ["model.layers.0.self_attn"] + worker._local_kv_cache_key = (0, worker.tp_rank) + worker.kv_caches_base_addr[worker.engine_id][worker._local_kv_cache_key] = [ + 0 + ] # Metadata with different kv_cache_layout than local worker meta = NixlAgentMetadata( @@ -1067,6 +1110,12 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( ssm_sizes=(0, 0), attn_backend_name=worker.backend_name, physical_blocks_per_logical_kv_block=1, + pp_rank=0, + pp_size=1, + start_layer=0, + end_layer=1, + registered_layer_indices=[0], + registered_layer_names=["model.layers.0.self_attn"], ) # We don't check layout for homogeneous TP and MLA for now, as the @@ -1587,18 +1636,27 @@ def test_register_kv_caches( kv_cache_tensors=[ KVCacheTensor( size=kv_cache_spec.page_size_bytes * num_blocks, - shared_by=["all-layers"], + shared_by=["model.layers.0.self_attn"], ) for _ in range(num_layers) ], - kv_cache_groups=[KVCacheGroupSpec(["all-layers"], kv_cache_spec)], + kv_cache_groups=[ + KVCacheGroupSpec(["model.layers.0.self_attn"], kv_cache_spec) + ], ) else: kv_cache_config = KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec(["layer0", 
"layer1", "layer2"], kv_cache_spec) + KVCacheGroupSpec( + [ + "model.layers.0.self_attn", + "model.layers.1.self_attn", + "model.layers.2.self_attn", + ], + kv_cache_spec, + ) ], ) # Create connector @@ -1668,7 +1726,7 @@ def test_register_kv_caches( expected_blocks_count = num_blocks * (2 if virtually_split else 1) - kv_caches = {"all-layers": cross_layers_kv_cache} + kv_caches = {"model.layers.0.self_attn": cross_layers_kv_cache} else: # Create test kv cache tensors using proper backend shape kv_cache_shape = backend_cls.get_kv_cache_shape( @@ -1680,9 +1738,9 @@ def test_register_kv_caches( shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, + "model.layers.0.self_attn": shared_tensor, + "model.layers.1.self_attn": unique_tensor, + "model.layers.2.self_attn": shared_tensor, } # Store tensor info for validation @@ -1726,6 +1784,12 @@ def test_register_kv_caches( f"got {base_addr}" ) + # The local NIXL transfer-descriptor dlist is now registered lazily (on + # the first remote handshake) rather than eagerly in register_kv_caches; + # trigger that registration so the get_xfer_descs assertions below see + # the local block layout. 
+ connector.connector_worker.register_local_xfer_handler(block_size) + # Verify get_xfer_descs was called with blocks_data assert mock_wrapper_instance.get_xfer_descs.called blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] @@ -1853,11 +1917,11 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init): ): worker._recving_transfers = {"req1": [123]} # Mock register_kv_cache which registers local handle - worker.src_xfer_handles_by_block_size = {worker.block_size: 455} + worker.src_xfer_handles_by_remote = {("engine1", 0, worker.block_size): 455} # P TP = 2 * D TP case, we should register 2 local handles - worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]} - worker.dst_xfer_side_handles = {"engine1": {0: 789}} - worker._remote_agents = {"engine1": {0: "agent1"}} + worker.src_xfer_handles_by_shard_tp_ratio = {("engine1", 0, -2): [456, 457]} + worker.dst_xfer_side_handles = {"engine1": {(0, 0): 789}} + worker._remote_agents = {"engine1": {(0, 0): "agent1"}} worker._registered_descs = ["desc1", "desc2"] mock_listener.is_alive.return_value = False @@ -2507,6 +2571,12 @@ def test_compatibility_hash_validation( ssm_sizes=(0, 0), attn_backend_name=decode_worker.backend_name, physical_blocks_per_logical_kv_block=1, + pp_rank=0, + pp_size=1, + start_layer=0, + end_layer=1, + registered_layer_indices=[0], + registered_layer_names=["model.layers.0.self_attn"], ) handshake_payload = NixlHandshakePayload( compatibility_hash=remote_hash, @@ -2531,6 +2601,7 @@ def test_compatibility_hash_validation( host="localhost", port=1234, remote_tp_size=1, + remote_pp_size=1, expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, ) else: @@ -2538,11 +2609,12 @@ def test_compatibility_hash_validation( host="localhost", port=1234, remote_tp_size=1, + remote_pp_size=1, expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, ) # Verify handshake returned agent mapping assert isinstance(result, dict) - assert len(result) == 1 + assert set(result) == 
{(0, 0)} @pytest.mark.parametrize( @@ -2630,6 +2702,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) host="localhost", port=1234, remote_tp_size=1, + remote_pp_size=1, expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, ) @@ -2676,19 +2749,28 @@ def test_mla_broadcast_notif_uses_remote_request_id( # / `dst_xfer_side_handles` to be keyed by remote rank. remote_engine_id = "remote_engine" worker.transfer_topo.register_remote_engine( - remote_engine_id=remote_engine_id, - remote_tp_size=prefill_tp_size, - remote_block_size=worker.block_size, - remote_block_len=worker.block_size * 4096, - remote_physical_blocks_per_logical=1, - local_block_len=worker.block_size * 4096, + remote_engine_id, + EngineTransferInfo( + remote_tp_size=prefill_tp_size, + remote_block_len=worker.block_size * 4096, + remote_block_size=worker.block_size, + remote_physical_blocks_per_logical=1, + remote_pp_rank=0, + start_layer=0, + end_layer=1, + ), ) worker._remote_agents[remote_engine_id] = { - rank: f"agent_p{rank}" for rank in range(prefill_tp_size) + (0, rank): f"agent_p{rank}" for rank in range(prefill_tp_size) } worker.dst_xfer_side_handles = { - remote_engine_id: {rank: 100 + rank for rank in range(prefill_tp_size)} + remote_engine_id: {(0, rank): 100 + rank for rank in range(prefill_tp_size)} } + worker._pp_layer_map[remote_engine_id] = SimpleNamespace(pp_size=1) + worker.tp_mappings[(remote_engine_id, 0)] = SimpleNamespace( + all_source_ranks=(0,), + source_ranks_per_group=((0,),), + ) # Sanity: D TP=1, P TP=4 => tp_ratio = -4 (P > D). assert worker.transfer_topo.tp_ratio(prefill_tp_size) == -prefill_tp_size @@ -2735,7 +2817,7 @@ def test_mla_broadcast_notif_uses_remote_request_id( # Broadcast goes to ranks {1, 2, 3} only, never to the read target. 
expected_recipients = { - worker._remote_agents[remote_engine_id][r] + worker._remote_agents[remote_engine_id][(0, r)] for r in range(1, prefill_tp_size) } assert {agent for agent, _ in send_notif_calls} == expected_recipients diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 6d4e6565e373..93f7a5f9d849 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -3,6 +3,7 @@ """Unit tests for NixlConnectorScheduler with HMA and Mamba N-1 prefill.""" import gc +from types import SimpleNamespace from unittest.mock import patch import pytest @@ -179,11 +180,12 @@ def test_read_blocks_for_req_expands_remote_ids( remote_info.remote_physical_blocks_per_logical = remote_physical_per_logical worker.transfer_topo.get_engine_info.return_value = remote_info worker.use_mla = False + worker._pp_layer_map = {remote_engine_id: SimpleNamespace(pp_size=1)} mock_plan = MagicMock(spec=TPMapping) mock_plan.all_source_ranks = () mock_plan.source_ranks_per_group = () - worker.tp_mappings = {remote_engine_id: mock_plan} + worker.tp_mappings = {(remote_engine_id, 0): mock_plan} metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( @@ -200,10 +202,31 @@ def test_read_blocks_for_req_expands_remote_ids( ) meta = metadata.reqs_to_recv["test-req"] - worker._read_blocks_for_req("test-req", meta) - assert meta.remote.block_ids == expected_remote_block_ids, ( - f"Expected {expected_remote_block_ids}, got {meta.remote.block_ids}" + # pr1 keeps meta.remote.block_ids logical (reused on retry/failure paths) + # and applies the per-PP-stage kernel expansion to a local copy used for + # the reads. Spy on the expansion helper to capture the kernel block IDs. 
+ real_expand = NixlConnectorWorker._logical_to_remote_kernel_block_ids + captured = [] + + def _capture_expand(self, logical_ids, phys_per_logical): + result = real_expand(self, logical_ids, phys_per_logical) + captured.append(result) + return result + + with patch.object( + NixlConnectorWorker, + "_logical_to_remote_kernel_block_ids", + _capture_expand, + ): + worker._read_blocks_for_req("test-req", meta) + + # Remote IDs stay logical (unchanged) for retry/failure reuse. + assert meta.remote.block_ids == remote_block_ids + # The expansion still produces the expected kernel block IDs. + assert captured, "expansion helper was not called" + assert captured[0] == expected_remote_block_ids, ( + f"Expected {expected_remote_block_ids}, got {captured[0]}" ) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 1b892849d909..9647aaf4438b 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -447,7 +447,7 @@ def make_kv_cache_config( ) -> KVCacheConfig: kv_cache_groups = [ KVCacheGroupSpec( - ["layer0", "layer2"], + ["model.layers.0.self_attn", "model.layers.2.self_attn"], FullAttentionSpec( block_size=block_size, num_kv_heads=4, @@ -459,7 +459,7 @@ def make_kv_cache_config( if swa_enabled: kv_cache_groups.append( KVCacheGroupSpec( - ["layer1", "layer3"], + ["model.layers.1.self_attn", "model.layers.3.self_attn"], SlidingWindowSpec( block_size=block_size, num_kv_heads=4, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/__init__.py index ed5c892fb9df..473f6f60a0ac 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/__init__.py @@ -10,6 +10,9 @@ NixlConnectorMetadata, NixlHandshakePayload, ) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.pp_layer_map import ( + PPLayerMap, +) from 
vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import ( NixlConnectorScheduler, ) @@ -28,4 +31,5 @@ "NixlConnectorWorker", "NixlHandshakePayload", "NixlKVConnectorStats", + "PPLayerMap", ] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py index 187322b4ae4e..669dd57bf3f0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py @@ -181,11 +181,15 @@ def request_finished_all_groups( assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) - def set_xfer_handshake_metadata( - self, metadata: dict[int, KVConnectorHandshakeMetadata] + def set_xfer_handshake_metadata_pp_aware( + self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata] ) -> None: """ - Set the KV connector handshake metadata for this connector. + Set the PP-aware KV connector handshake metadata for this connector. + + Engine core hands NIXL the full ``{(pp_rank, tp_rank): metadata}`` + dict. Non-PP unwrap for legacy connectors happens in engine core + before the dispatch, so NIXL never sees the int-keyed shape. Args: metadata (dict): the handshake metadata to set. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py index b9e3436f5019..f0bf59b54d6d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py @@ -34,8 +34,13 @@ # 2: Add remote_request_id to kv_transfer_params # 3: Add physical_blocks_per_logical_kv_block to NixlAgentMetadata # 4: Add KV block lease renewal through heartbeats +# 5: Add pipeline-parallel producer metadata (pp_rank, pp_size, +# start_layer, end_layer, registered_layer_indices, +# registered_layer_names) and per-request pp_size. 
NIXL regions are +# advertised per layer-name so HMA pool composition may differ across PP +# producer shards and the decode consumer. # -NIXL_CONNECTOR_VERSION: int = 4 +NIXL_CONNECTOR_VERSION: int = 5 @dataclass @@ -51,6 +56,12 @@ class NixlAgentMetadata: ssm_sizes: tuple[int, int] attn_backend_name: str physical_blocks_per_logical_kv_block: int + pp_rank: int + pp_size: int + start_layer: int + end_layer: int + registered_layer_indices: list[int] + registered_layer_names: list[str] @dataclass @@ -142,6 +153,7 @@ class HeartbeatInfo: host: str port: int tp_size: int + pp_size: int @dataclass @@ -159,6 +171,7 @@ class ReqMeta: # To be used when logical block size does not match the kernel block size local_physical_block_ids: BlockIds tp_size: int + pp_size: int = 1 remote: RemoteMeta | None = None @@ -182,6 +195,7 @@ def _add_new_req( local_physical_block_ids=local_block_ids, # P workers don't need to receive tp_size from proxy here. tp_size=kv_transfer_params.get("tp_size", 1), + pp_size=kv_transfer_params.get("pp_size", 1), ) def add_new_req_to_save( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/pp_layer_map.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/pp_layer_map.py new file mode 100644 index 000000000000..4793e14f146f --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/pp_layer_map.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Pipeline-parallel layer map helpers for NIXL metadata.""" + +from dataclasses import dataclass + +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( + NixlAgentMetadata, +) + + +@dataclass(frozen=True) +class PPLayerMap: + pp_size: int + boundaries: tuple[tuple[int, int], ...] + registered_layer_indices: tuple[tuple[int, ...], ...] 
+ total_num_hidden_layers: int + + @classmethod + def from_metadata_shards( + cls, metas: list[NixlAgentMetadata], total_num_hidden_layers: int + ) -> "PPLayerMap": + assert metas, "cannot build PPLayerMap without metadata shards" + pp_size = metas[0].pp_size + + boundaries_by_rank: dict[int, tuple[int, int]] = {} + registered_by_rank: dict[int, tuple[int, ...]] = {} + for meta in metas: + boundary = (meta.start_layer, meta.end_layer) + registered_layers = tuple(meta.registered_layer_indices) + prior_boundary = boundaries_by_rank.get(meta.pp_rank) + if prior_boundary is not None: + if ( + prior_boundary != boundary + or registered_by_rank[meta.pp_rank] != registered_layers + ): + raise ValueError( + f"conflicting metadata shards for pp_rank {meta.pp_rank}" + ) + continue + boundaries_by_rank[meta.pp_rank] = boundary + registered_by_rank[meta.pp_rank] = registered_layers + + assert len(boundaries_by_rank) == pp_size, ( + "missing metadata shards for pp_rank(s): " + f"{sorted(set(range(pp_size)) - boundaries_by_rank.keys())}" + ) + + expected_start = 0 + for pp_rank in range(pp_size): + start, end = boundaries_by_rank[pp_rank] + if start != expected_start: + raise ValueError( + "PP layer boundaries must be contiguous; pp_rank " + f"{pp_rank} starts at {start}, expected {expected_start}" + ) + expected_start = end + if expected_start != total_num_hidden_layers: + raise ValueError( + "PP layer boundaries must cover all hidden layers; last end " + f"{expected_start}, expected {total_num_hidden_layers}" + ) + + return cls( + pp_size=pp_size, + boundaries=tuple(boundaries_by_rank[rank] for rank in range(pp_size)), + registered_layer_indices=tuple( + registered_by_rank[rank] for rank in range(pp_size) + ), + total_num_hidden_layers=total_num_hidden_layers, + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py index b2122ed0d30b..96f0e5c8b811 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py @@ -185,6 +185,7 @@ def on_new_request(self, request: "Request") -> None: host = params.get("remote_host") port = params.get("remote_port") tp_size = params.get("tp_size") + pp_size = params.get("pp_size", 1) if ( remote_engine_id is None or remote_request_id is None @@ -199,6 +200,7 @@ def on_new_request(self, request: "Request") -> None: host=host, port=port, tp_size=tp_size, + pp_size=pp_size, ) self._heartbeat_by_engine[remote_engine_id].req_ids.add(remote_request_id) self._heartbeat_req_engine[request.request_id] = ( @@ -244,7 +246,7 @@ def get_sw_clipped_blocks(self, block_ids: BlockIds) -> BlockIds: ) def set_xfer_handshake_metadata( - self, metadata: dict[int, KVConnectorHandshakeMetadata] + self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata] ) -> None: """ Set the KV connector handshake metadata for this connector. @@ -252,19 +254,20 @@ def set_xfer_handshake_metadata( Args: metadata (dict): the handshake metadata to set. """ - encoded_data: dict[int, bytes] = {} + encoded_data: dict[tuple[int, int], bytes] = {} encoder = msgspec.msgpack.Encoder() - for tp_rank, rank_metadata in metadata.items(): + for (pp_rank, tp_rank), rank_metadata in metadata.items(): if not isinstance(rank_metadata, NixlHandshakePayload): raise ValueError( "NixlConnectorScheduler expects NixlHandshakePayload for " "handshake metadata." ) - encoded_data[tp_rank] = encoder.encode(rank_metadata) + encoded_data[(pp_rank, tp_rank)] = encoder.encode(rank_metadata) logger.debug( - "Tp rank %d: encoded NixlHandshakePayload size: %s bytes", + "PP rank %d, TP rank %d: encoded NixlHandshakePayload size: %s bytes", + pp_rank, tp_rank, - str(len(encoded_data[tp_rank])), + str(len(encoded_data[(pp_rank, tp_rank)])), ) # Only start the listener when we have metadata to serve. 
@@ -287,7 +290,7 @@ def set_xfer_handshake_metadata( @staticmethod def _nixl_handshake_listener( - encoded_data: dict[int, Any], + encoded_data: dict[tuple[int, int], Any], ready_event: threading.Event, stop_event: threading.Event, host: str, @@ -310,15 +313,19 @@ def _nixl_handshake_listener( if stop_event.is_set(): break continue - # Decode the message which contains (GET_META_MSG, rank) - msg, target_tp_rank = msgspec.msgpack.decode(msg) + # Decode the message which contains + # (GET_META_MSG, pp_rank, tp_rank) + msg, target_pp_rank, target_tp_rank = msgspec.msgpack.decode(msg) logger.debug( - "Received message for tp rank %s", + "Received message for pp rank %s, tp rank %s", + target_pp_rank, target_tp_rank, ) if msg != GET_META_MSG: logger.warning("Connection listener got unexpected message %s", msg) - sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) + sock.send_multipart( + (identity, b"", encoded_data[(target_pp_rank, target_tp_rank)]) + ) def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int: """D-side only. 
Returns N-1 for Mamba models since the decoder @@ -670,5 +677,6 @@ def request_finished( remote_host=self.side_channel_host, remote_port=self.side_channel_port, tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + pp_size=self.vllm_config.parallel_config.pipeline_parallel_size, remote_num_tokens=remote_num_tokens, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 0d30d4a692ad..082ac707c5c9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -40,6 +40,9 @@ TransferHandle, compute_nixl_compatibility_hash, ) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.pp_layer_map import ( + PPLayerMap, +) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import ( NixlKVConnectorStats, ) @@ -61,10 +64,12 @@ ) from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger +from vllm.model_executor.models.utils import extract_layer_index from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path from vllm.v1.attention.backends.utils import get_kv_cache_layout @@ -235,6 +240,11 @@ def __init__( for group in kv_cache_config.kv_cache_groups for layer in group.layer_names } + self._layer_name_to_kv_group_index = { + layer: group_idx + for group_idx, group in enumerate(kv_cache_config.kv_cache_groups) + for layer in group.layer_names + } self.hma_group_size = len(kv_cache_config.kv_cache_tensors) # ---- Model state (derived from model config) ---- @@ -291,8 +301,10 @@ def __init__( ) self.nixl_wrapper = nixl_wrapper_cls(str(uuid.uuid4()), config) - # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. 
- self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) + # Map of engine_id -> {(pp_rank, tp_rank): agent_name}. + self._remote_agents: dict[EngineId, dict[tuple[int, int], str]] = defaultdict( + dict + ) # Metadata. self.engine_id: EngineId = engine_id @@ -356,23 +368,46 @@ def __init__( # Note: host xfer buffer ops when use_host_buffer is True self.copy_blocks: CopyBlocksOp | None = None - # Map of engine_id -> kv_caches_base_addr. For TP case, each local + self._local_kv_cache_key = (0, self.tp_rank) + self.device_id: int = 0 - # Current rank may pull from multiple remote TP workers. - # EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer - self.kv_caches_base_addr = defaultdict[EngineId, dict[int, list[int]]](dict) + # Map of engine_id -> kv_caches_base_addr. Under heterogeneous + # PP x TP the local rank may pull from multiple remote + # (pp_rank, tp_rank) shards, so we key by the producer shard tuple + # rather than just tp_rank. + # EngineId, dict[(pp_rank, tp_rank), list[int]] + # -> engine_id, (pp_rank, tp_rank), base_addr_for_layer + self.kv_caches_base_addr = defaultdict[ + EngineId, dict[tuple[int, int], list[int]] + ](dict) + self.local_seen_layer_names: list[str] = [] + self._remote_agent_metadata: dict[ + EngineId, dict[tuple[int, int], NixlAgentMetadata] + ] = defaultdict(dict) + self._pp_layer_map: dict[EngineId, PPLayerMap] = {} # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) self.num_regions = 0 - # nixl_prepped_dlist_handle. - self.src_xfer_handles_by_block_size: dict[int, int] = {} + # nixl_prepped_dlist_handle, keyed by remote shard + block size. + self.src_xfer_handles_by_remote: dict[tuple[EngineId, int, int], int] = {} + self.src_blocks_data_by_remote: dict[ + tuple[EngineId, int, int], list[tuple[int, int, int]] + ] = {} # Populated dynamically during handshake based on remote configuration. 
- # Keep track of regions at different tp_ratio values. tp_ratio->handles - self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {} - # Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}. - self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict) + self.src_xfer_handles_by_shard_tp_ratio: dict[ + tuple[EngineId, int, int], list[int] + ] = {} + # Map of engine_id -> {(pp_rank, tp_rank): nixl_prepped_dlist_handle}. + self.dst_xfer_side_handles = defaultdict[EngineId, dict[tuple[int, int], int]]( + dict + ) + # Per-shard descriptor layout: (num_blocks, region_group_ids) + # keyed by (engine_id, remote_pp_rank, "local" | "remote"). + self._xfer_desc_layouts: dict[ + tuple[EngineId, int, str], tuple[int, tuple[int, ...]] + ] = {} # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. @@ -404,7 +439,7 @@ def __init__( thread_name_prefix="vllm-nixl-handshake-initiator", ) self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() - self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} + self._handshake_futures: dict[EngineId, Future[dict[tuple[int, int], str]]] = {} # Protects _handshake_futures and _remote_agents. self._handshake_lock = threading.RLock() @@ -445,7 +480,7 @@ def __init__( ) # Per-engine TP mappings. Generated during handshake. 
- self.tp_mappings: dict[EngineId, TPMapping] = {} + self.tp_mappings: dict[tuple[EngineId, int], TPMapping] = {} self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( "enforce_handshake_compat", True @@ -470,13 +505,91 @@ def _sync_block_size_with_kernel(self) -> None: self.block_size = kernel_block_size self.num_blocks *= self._physical_blocks_per_logical_kv_block + def _get_local_base_addresses(self) -> list[int]: + return self.kv_caches_base_addr[self.engine_id].get( + self._local_kv_cache_key, [] + ) + + def _local_region_indices_for_layer_names( + self, registered_layer_names: list[str] + ) -> list[int]: + local_names = self.local_seen_layer_names + positions_by_name: dict[str, list[int]] = defaultdict(list) + for local_idx, layer_name in enumerate(local_names): + positions_by_name[layer_name].append(local_idx) + + occurrences_by_name: dict[str, int] = defaultdict(int) + local_indices: list[int] = [] + for layer_name in registered_layer_names: + occurrence = occurrences_by_name[layer_name] + occurrences_by_name[layer_name] += 1 + matches = positions_by_name.get(layer_name, []) + if occurrence >= len(matches): + raise RuntimeError( + "NIXL handshake failed: producer registered layer " + f"{layer_name!r} occurrence {occurrence} has no matching " + f"local region. Local registered layers: {local_names}" + ) + local_indices.append(matches[occurrence]) + return local_indices + + def _region_group_ids_for_layer_names( + self, registered_layer_names: list[str] + ) -> tuple[int, ...]: + # Non-MLA split-K/V backends register K and V as separate regions + # that share the same KV-group id, hence the x2 duplication. 
+ group_ids = tuple( + self._layer_name_to_kv_group_index[name] for name in registered_layer_names + ) + assert self.transfer_topo is not None + if self.transfer_topo.is_kv_layout_blocks_first: + return tuple(g for g in group_ids for _ in range(2)) + return group_ids + + def _try_update_pp_layer_map( + self, engine_id: EngineId, remote_pp_size: int + ) -> PPLayerMap | None: + metas = list(self._remote_agent_metadata[engine_id].values()) + if not metas: + return None + if len({meta.pp_rank for meta in metas}) < remote_pp_size: + return None + pp_map = self._build_pp_layer_map_from_metas(metas, remote_pp_size) + assert pp_map.pp_size == remote_pp_size + self._pp_layer_map[engine_id] = pp_map + return pp_map + + def _build_pp_layer_map_from_metas( + self, metas: list[NixlAgentMetadata], remote_pp_size: int + ) -> PPLayerMap: + total_layers = self.model_config.get_total_num_hidden_layers() + try: + return PPLayerMap.from_metadata_shards(metas, total_layers) + except ValueError: + if remote_pp_size != 1: + raise + meta = metas[0] + assert meta.pp_rank == 0 and meta.pp_size == 1 + assert all(0 <= g < total_layers for g in meta.registered_layer_indices) + logger.warning( + "Non-PP NIXL metadata did not advertise full model coverage; " + "using a single-shard compatibility layer map." + ) + return PPLayerMap( + pp_size=1, + boundaries=((0, total_layers),), + registered_layer_indices=(tuple(meta.registered_layer_indices),), + total_num_hidden_layers=total_layers, + ) + def _nixl_handshake( self, host: str, port: int, remote_tp_size: int, + remote_pp_size: int, expected_engine_id: str, - ) -> dict[int, str]: + ) -> dict[tuple[int, int], str]: """Do a NIXL handshake with a remote instance.""" # the first time we connect to a remote agent. @@ -498,95 +611,113 @@ def _nixl_handshake( # this happens to be the same single rank_i. 
assert self.transfer_topo is not None p_remote_ranks = self.transfer_topo.handshake_target_ranks(remote_tp_size) - remote_rank_to_agent_name = {} + remote_rank_to_agent_name: dict[tuple[int, int], str] = {} + metadata_shards: list[NixlAgentMetadata] = [] path = make_zmq_path("tcp", host, port) with zmq_ctx(zmq.REQ, path) as sock: - for remote_rank in p_remote_ranks: - logger.debug( - "Querying metadata on path: %s at remote tp rank %s", - path, - remote_rank, - ) + for remote_pp_rank in range(remote_pp_size): + for remote_rank in p_remote_ranks: + logger.debug( + "Querying metadata on path: %s at remote pp rank %s, " + "tp rank %s", + path, + remote_pp_rank, + remote_rank, + ) - start_time = time.perf_counter() - # Send query for the request. - msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank)) - # Set receive timeout to 5 seconds to avoid hanging on dead server - sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds - sock.send(msg) - handshake_bytes = sock.recv() + start_time = time.perf_counter() + # Send query for the request. + msg = msgspec.msgpack.encode( + (GET_META_MSG, remote_pp_rank, remote_rank) + ) + # Set receive timeout to 5 seconds to avoid hanging on dead server + sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds + sock.send(msg) + handshake_bytes = sock.recv() - # Decode handshake payload to get compatibility hash - handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload) - try: - handshake_payload = handshake_decoder.decode(handshake_bytes) - except (msgspec.DecodeError, msgspec.ValidationError) as e: - raise RuntimeError( - f"Failed to decode NixlHandshakePayload. This likely indicates " - f"an incompatibility between connector version. 
Error: {e}" - ) from e - - got_metadata_time = time.perf_counter() - logger.debug( - "NIXL handshake: get metadata took: %s", - got_metadata_time - start_time, - ) + # Decode handshake payload to get compatibility hash + handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload) + try: + handshake_payload = handshake_decoder.decode(handshake_bytes) + except (msgspec.DecodeError, msgspec.ValidationError) as e: + raise RuntimeError( + "Failed to decode NixlHandshakePayload. This likely " + "indicates an incompatibility between connector " + f"version. Error: {e}" + ) from e + + got_metadata_time = time.perf_counter() + logger.debug( + "NIXL handshake: get metadata took: %s", + got_metadata_time - start_time, + ) - # Check compatibility hash BEFORE decoding agent metadata - assert self.compat_hash is not None - if ( - self.enforce_compat_hash - and handshake_payload.compatibility_hash != self.compat_hash - ): - raise RuntimeError( - f"NIXL compatibility hash mismatch. " - f"Local: {self.compat_hash}, " - f"Remote: {handshake_payload.compatibility_hash}. " - f"Prefill and decode instances have incompatible " - f"configurations. This may be due to: different vLLM versions," - f" models, dtypes, KV cache layouts, attention backends, etc. " - f"Both instances must use identical configurations." - f"Disable this check using " - f'--kv-transfer-config \'{{"kv_connector_extra_config": ' - f'{{"enforce_handshake_compat": false}}}}\'' + # Check compatibility hash BEFORE decoding agent metadata + assert self.compat_hash is not None + if ( + self.enforce_compat_hash + and handshake_payload.compatibility_hash != self.compat_hash + ): + raise RuntimeError( + f"NIXL compatibility hash mismatch. " + f"Local: {self.compat_hash}, " + f"Remote: {handshake_payload.compatibility_hash}. " + "Prefill and decode instances have incompatible " + "configurations. This may be due to: different vLLM " + "versions, models, dtypes, KV cache layouts, attention " + "backends, etc. 
Both instances must use identical " + "configurations. Disable this check using " + "--kv-transfer-config " + '\'{"kv_connector_extra_config": ' + '{"enforce_handshake_compat": false}}\'' + ) + + logger.info( + "NIXL compatibility check passed (hash: %s)", + handshake_payload.compatibility_hash, ) - logger.info( - "NIXL compatibility check passed (hash: %s)", - handshake_payload.compatibility_hash, - ) + # Decode agent metadata + metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + try: + metadata = metadata_decoder.decode( + handshake_payload.agent_metadata_bytes + ) + except (msgspec.DecodeError, msgspec.ValidationError) as e: + # This should not happen if hash matched + raise RuntimeError( + f"Failed to decode NixlAgentMetadata. Error: {e}" + ) from e + + # Ensure engine id matches. + if metadata.engine_id != expected_engine_id: + raise RuntimeError( + f"Remote NIXL agent engine ID mismatch. " + f"Expected {expected_engine_id}," + f"received {metadata.engine_id}." + ) - # Decode agent metadata - metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - try: - metadata = metadata_decoder.decode( - handshake_payload.agent_metadata_bytes + # Register Remote agent. + remote_agent_name = self.add_remote_agent( + metadata, + remote_rank, + remote_tp_size, + remote_pp_rank=remote_pp_rank, + remote_pp_size=remote_pp_size, ) - except (msgspec.DecodeError, msgspec.ValidationError) as e: - # This should not happen if hash matched - raise RuntimeError( - f"Failed to decode NixlAgentMetadata. Error: {e}" - ) from e - - # Ensure engine id matches. - if metadata.engine_id != expected_engine_id: - raise RuntimeError( - f"Remote NIXL agent engine ID mismatch. " - f"Expected {expected_engine_id}," - f"received {metadata.engine_id}." + metadata_shards.append(metadata) + setup_agent_time = time.perf_counter() + logger.debug( + "NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time, ) - - # Register Remote agent. 
- remote_agent_name = self.add_remote_agent( - metadata, remote_rank, remote_tp_size - ) - setup_agent_time = time.perf_counter() - logger.debug( - "NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time, - ) - remote_rank_to_agent_name[remote_rank] = remote_agent_name + remote_rank_to_agent_name[(remote_pp_rank, remote_rank)] = ( + remote_agent_name + ) + self._pp_layer_map[expected_engine_id] = self._build_pp_layer_map_from_metas( + metadata_shards, remote_pp_size + ) return remote_rank_to_agent_name def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: @@ -701,7 +832,8 @@ def _ensure_handshake( host: str, port: int, tp_size: int, - ) -> Future[dict[int, str]] | None: + pp_size: int, + ) -> Future[dict[tuple[int, int], str]] | None: """ Ensure a handshake is in-flight (or already done) for *engine_id*. @@ -712,7 +844,7 @@ def _ensure_handshake( Failures to handshake are logged and the request is marked as failed. """ with self._handshake_lock: - if engine_id in self._remote_agents: + if engine_id in self._remote_agents and engine_id in self._pp_layer_map: return None fut = self._handshake_futures.get(engine_id) if fut is not None: @@ -722,11 +854,12 @@ def _ensure_handshake( host, port, tp_size, + pp_size, engine_id, ) self._handshake_futures[engine_id] = fut - def done_callback(f: Future[dict[int, str]], eid=engine_id): + def done_callback(f: Future[dict[tuple[int, int], str]], eid=engine_id): with self._handshake_lock: del self._handshake_futures[eid] try: @@ -742,6 +875,15 @@ def done_callback(f: Future[dict[int, str]], eid=engine_id): fut.add_done_callback(done_callback) return fut + def _handshake_complete(self, engine_id: EngineId, pp_size: int) -> bool: + pp_map = self._pp_layer_map.get(engine_id) + return ( + engine_id in self._remote_agents + and engine_id not in self._handshake_futures + and pp_map is not None + and pp_map.pp_size == pp_size + ) + def _background_nixl_handshake( self, req_id: str, 
remote_engine_id: EngineId, meta: ReqMeta ): @@ -752,6 +894,7 @@ def _background_nixl_handshake( meta.remote.host, meta.remote.port, meta.tp_size, + meta.pp_size, ) if fut is None: # Already handshaked — only happens if caller does not pre-check. @@ -804,6 +947,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.compat_hash = compute_nixl_compatibility_hash( self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks ) + pp_size = self.vllm_config.parallel_config.pipeline_parallel_size + if self.transfer_topo.cross_layers_blocks and pp_size > 1: + raise RuntimeError( + "cross-layer-blocks mode is not supported with " + "pipeline_parallel_size > 1 yet." + ) + if self._has_mamba and pp_size > 1: + # Per-shard descriptor layouts for hybrid (Mamba/SSM) producers + # need mamba_region_count / mamba_region_group_ids tracking that + # the consumer descriptor builder does not yet implement. Follow-up + # PR adds this. + raise RuntimeError( + "Hybrid (Mamba/SSM) models are not yet supported with " + "pipeline_parallel_size > 1 over NIXL PD disaggregation." + ) + pp_rank = get_pp_group().rank_in_group + start_layer, end_layer = self.model_config.get_layers_start_end_indices( + self.vllm_config.parallel_config + ) if self.use_host_buffer: self.initialize_host_xfer_buffer(kv_caches=kv_caches) @@ -829,7 +991,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): caches_data = [] # With hybrid allocator, layers can share a kv cache tensor - seen_base_addresses = [] + seen_base_addresses: list[int] = [] # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -844,12 +1006,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Enable different block lengths for different layers *only* when MLA is used. # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`. 
self.block_len_per_layer = list[int]() + seen_layer_indices: list[int] = [] + seen_layer_names: list[str] = [] for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. # However, physical page_size may differ when kernel requires a specific # block size. This leads to SSM and FA layers having different num_blocks. # `_physical_blocks_per_logical_kv_block` ratio is used to adjust for this. + layer_index = extract_layer_index(layer_name) layer_spec = self._layer_specs[layer_name] if isinstance(layer_spec, UniformTypeKVCacheSpecs): # MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs @@ -905,6 +1070,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) else: self.block_len_per_layer.append(physical_page_size) + seen_layer_indices.append(layer_index) + seen_layer_names.append(layer_name) if cache.shape[0] != num_blocks: raise AssertionError( @@ -939,8 +1106,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): "Different block lengths collected: %s", set(self.block_len_per_layer) ) assert len(self.block_len_per_layer) == len(seen_base_addresses) - - self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses + assert all(start_layer <= idx < end_layer for idx in seen_layer_indices) + self._local_kv_cache_key = (pp_rank, self.tp_rank) + self.local_seen_layer_names = seen_layer_names + self.kv_caches_base_addr[self.engine_id][self._local_kv_cache_key] = ( + seen_base_addresses + ) self.num_regions = len(caches_data) if self.transfer_topo.virtually_split_kv_in_blocks: @@ -980,17 +1151,24 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): set(self.block_len_per_layer), ) - # Register local/src descr for NIXL xfer. 
- self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = ( - self.register_local_xfer_handler(self.block_size) - ) - # After KV Caches registered, listen for new connections. + num_regions = len(seen_base_addresses) + logger.info( + "Registering KV_Caches. pp_rank=%d/%d, layers=[%d, %d), " + "registered_regions=%d", + pp_rank, + pp_size, + start_layer, + end_layer, + num_regions, + ) agent_metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), device_id=self.device_id, - kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank], + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][ + self._local_kv_cache_key + ], num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, kv_cache_layout=self.kv_cache_layout @@ -1002,6 +1180,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): physical_blocks_per_logical_kv_block=( self._physical_blocks_per_logical_kv_block ), + pp_rank=pp_rank, + pp_size=pp_size, + start_layer=start_layer, + end_layer=end_layer, + registered_layer_indices=seen_layer_indices, + registered_layer_names=seen_layer_names, ) # Wrap metadata in payload with hash for defensive decoding assert self.compat_hash is not None @@ -1100,12 +1284,13 @@ def _build_fa_local( self, base_addresses: list[int], block_size_ratio: int, + local_region_indices: list[int], ) -> list[tuple[int, int, int]]: """Build local FA descriptors for all layers.""" assert self.transfer_topo is not None num_blocks = self.num_blocks * block_size_ratio result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate(base_addresses): + for i, base_addr in zip(local_region_indices, base_addresses): kv_block_len = ( self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False @@ -1137,6 +1322,7 @@ def _build_fa_remote( plan: TPMapping, nixl_agent_meta: NixlAgentMetadata, block_size_ratio: int, + local_region_indices: list[int], ) -> 
list[tuple[int, int, int]]: """Build remote FA descriptors for all layers.""" assert self.transfer_topo is not None @@ -1147,9 +1333,10 @@ def _build_fa_remote( num_blocks = nixl_agent_meta.num_blocks result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + local_region_idx = local_region_indices[i] # Read our whole local region size from remote.. local_block_len = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=False + layer_idx=local_region_idx, first_split=True, mamba_view=False ) remote_kv_block_len = local_block_len // block_size_ratio if block_size_ratio > 1: @@ -1170,7 +1357,7 @@ def _build_fa_remote( if self.transfer_topo.virtually_split_kv_in_blocks: # With FlashInfer index V separately to allow head splitting. second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False + layer_idx=local_region_idx, first_split=False, mamba_view=False ) second_split = second_split // num_attn_reads for block_id in range(num_blocks): @@ -1184,6 +1371,8 @@ def _build_fa_remote( def register_local_xfer_handler( self, block_size: int, + *, + registered_layer_names: list[str] | None = None, ) -> tuple[int, list[tuple[int, int, int]]]: """ Function used for register local xfer handler with local block_size or @@ -1198,9 +1387,24 @@ def register_local_xfer_handler( """ assert self.transfer_topo is not None block_size_ratio = self.block_size // block_size - local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] + # PP-aware: when registered_layer_names is provided, register only the + # local regions matching the producer shard's layers (in the producer's + # advertised order). Otherwise register all local regions. 
+ if registered_layer_names is None: + local_base_addresses = self._get_local_base_addresses() + local_region_indices = list(range(len(local_base_addresses))) + else: + local_region_indices = self._local_region_indices_for_layer_names( + registered_layer_names + ) + local_base_addresses_all = self._get_local_base_addresses() + local_base_addresses = [ + local_base_addresses_all[i] for i in local_region_indices + ] - blocks_data = self._build_fa_local(local_base_addresses, block_size_ratio) + blocks_data = self._build_fa_local( + local_base_addresses, block_size_ratio, local_region_indices + ) logger.debug( "Created %s blocks for src engine %s and rank %s on device id %s", len(blocks_data), @@ -1230,6 +1434,9 @@ def add_remote_agent( nixl_agent_meta: NixlAgentMetadata, remote_tp_rank: int = 0, remote_tp_size: int = 1, + *, + remote_pp_rank: int = 0, + remote_pp_size: int = 1, ) -> str: """ Add the remote NIXL agent and prepare the descriptors for reading cache @@ -1275,19 +1482,24 @@ def add_remote_agent( tp_ratio < 0 (P_TP > D_TP) are supported by the 3-read transfer. """ # noqa: E501 engine_id = nixl_agent_meta.engine_id + shard_key = (remote_pp_rank, remote_tp_rank) # TODO re-evaluate refreshing for scaling/recovery - if remote_tp_rank in self._remote_agents.get(engine_id, {}): + if shard_key in self._remote_agents.get(engine_id, {}): logger.debug( - "Remote agent with engine_id %s and rank" - "%s already exchanged metadata, skip handshake.", + "Remote agent with engine_id %s, pp rank %s, tp rank %s " + "already exchanged metadata, skip handshake.", engine_id, + remote_pp_rank, remote_tp_rank, ) - return self._remote_agents[engine_id][remote_tp_rank] + return self._remote_agents[engine_id][shard_key] ### Register remote engine in TransferTopology (idempotent). 
assert self.transfer_topo is not None transfer_topo = self.transfer_topo + assert nixl_agent_meta.pp_rank == remote_pp_rank + assert nixl_agent_meta.pp_size == remote_pp_size + assert len(nixl_agent_meta.registered_layer_indices) > 0 physical_blocks_per_logical = ( nixl_agent_meta.physical_blocks_per_logical_kv_block ) @@ -1296,11 +1508,17 @@ def add_remote_agent( remote_block_size=nixl_agent_meta.block_size, remote_block_len=nixl_agent_meta.block_lens[0], remote_physical_blocks_per_logical=physical_blocks_per_logical, + remote_pp_rank=remote_pp_rank, + start_layer=nixl_agent_meta.start_layer, + end_layer=nixl_agent_meta.end_layer, ) transfer_topo.register_remote_engine(engine_id, transfer_info) - logger.info("Transfer plan: %s", transfer_topo.describe(engine_id)) + logger.info( + "Transfer plan: %s", transfer_topo.describe(engine_id, remote_pp_rank) + ) - self.tp_mappings[engine_id] = compute_tp_mapping( + plan_key = (engine_id, remote_pp_rank) + self.tp_mappings[plan_key] = compute_tp_mapping( transfer_topology=transfer_topo, remote_tp_size=remote_tp_size, group_spec_types=self._group_spec_types, @@ -1321,47 +1539,85 @@ def add_remote_agent( if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + else: + assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks # Keep track of remote agent kv caches base addresses. - self.kv_caches_base_addr[engine_id][remote_tp_rank] = ( + self.kv_caches_base_addr[engine_id][shard_key] = ( nixl_agent_meta.kv_caches_base_addr ) - self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) + self._remote_agent_metadata[engine_id][shard_key] = nixl_agent_meta + self._try_update_pp_layer_map(engine_id, remote_pp_size) + self._validate_remote_agent_handshake( + nixl_agent_meta, + remote_pp_rank, + remote_pp_size, + remote_tp_size, + ) # This is 1 when P and D `--tensor-parallel-size` match. Otherwise, # this is the ratio between the two sizes. 
tp_ratio = transfer_topo.tp_ratio(remote_tp_size) logger.debug( - "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s", + "Registering remote agent (%s, pp rank %s, tp rank %s) memory " + "regions with tp_ratio %s", engine_id, + remote_pp_rank, remote_tp_rank, tp_ratio, ) - plan = self.tp_mappings[engine_id] + plan = self.tp_mappings[plan_key] + + # PP-aware: lazily register a local xfer handler for this producer + # shard's layers + block size (idempotent across (engine, pp_rank, + # block_size)). Replaces main's eager call in register_kv_caches plus + # the `block_size_ratio > 1` follow-up. + local_handle_key = (engine_id, remote_pp_rank, nixl_agent_meta.block_size) + if local_handle_key not in self.src_xfer_handles_by_remote: + handle, blocks_data = self.register_local_xfer_handler( + nixl_agent_meta.block_size, + registered_layer_names=nixl_agent_meta.registered_layer_names, + ) + self.src_xfer_handles_by_remote[local_handle_key] = handle + self.src_blocks_data_by_remote[local_handle_key] = blocks_data + self._xfer_desc_layouts[(engine_id, remote_pp_rank, "local")] = ( + self.num_blocks * block_size_ratio, + self._region_group_ids_for_layer_names( + nixl_agent_meta.registered_layer_names + ), + ) + src_blocks_data = self.src_blocks_data_by_remote[local_handle_key] + local_num_blocks, local_region_group_ids = self._xfer_desc_layouts[ + (engine_id, remote_pp_rank, "local") + ] + local_region_indices = self._local_region_indices_for_layer_names( + nixl_agent_meta.registered_layer_names + ) ### (Optional) Register local agent memory regions. MLA is not split. + split_handle_key = (engine_id, remote_pp_rank, tp_ratio) if ( tp_ratio < 0 and not self.use_mla - and tp_ratio not in self.src_xfer_handles_by_tp_ratio + and split_handle_key not in self.src_xfer_handles_by_shard_tp_ratio ): # Remote tp_size > local tp_size: read from multiple remote ranks. # Logically "split" own regions into |tp_ratio| chunks. 
Mind that # we only do this once per remote tp_size (replica-friendly). - self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] + self.src_xfer_handles_by_shard_tp_ratio[split_handle_key] = [] for handle_data in self._build_local_splits_from_plan( plan, - self.src_blocks_data, - self.num_descs, + src_blocks_data, + len(local_region_group_ids) * local_num_blocks, ): descs = self.nixl_wrapper.get_xfer_descs( handle_data, self.nixl_memory_type ) handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) - self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) + self.src_xfer_handles_by_shard_tp_ratio[split_handle_key].append(handle) ### Register remote agent memory regions # With homogeneous TP, D pulls the whole kv cache from corresponding rank. With @@ -1374,11 +1630,14 @@ def add_remote_agent( plan, nixl_agent_meta, block_size_ratio, + local_region_indices, ) logger.debug( - "Created %s blocks for dst engine %s with remote rank %s and local rank %s", + "Created %s blocks for dst engine %s with remote pp rank %s, " + "remote tp rank %s and local rank %s", len(blocks_data), engine_id, + remote_pp_rank, remote_tp_rank, self.tp_rank, ) @@ -1398,21 +1657,26 @@ def add_remote_agent( # Register with NIXL. 
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) - self.dst_xfer_side_handles[engine_id][remote_tp_rank] = ( + self.dst_xfer_side_handles[engine_id][shard_key] = ( self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs) ) - if block_size_ratio > 1: - # when prefill with smaller block_size, we need to init a - # new handler with same block_len to match - self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = ( - self.register_local_xfer_handler(nixl_agent_meta.block_size)[0] - ) + self._xfer_desc_layouts[(engine_id, remote_pp_rank, "remote")] = ( + nixl_agent_meta.num_blocks, + self._region_group_ids_for_layer_names( + nixl_agent_meta.registered_layer_names + ), + ) + self._remote_agents[engine_id][shard_key] = remote_agent_name return remote_agent_name def _validate_remote_agent_handshake( - self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int + self, + nixl_agent_meta: NixlAgentMetadata, + remote_pp_rank: int, + remote_pp_size: int, + remote_tp_size: int, ): """ Validate the remote agent handshake metadata ensuring the @@ -1421,7 +1685,33 @@ def _validate_remote_agent_handshake( remote_engine_id = nixl_agent_meta.engine_id assert self.transfer_topo is not None - remote_info = self.transfer_topo.get_engine_info(remote_engine_id) + assert nixl_agent_meta.pp_rank == remote_pp_rank + assert nixl_agent_meta.pp_size == remote_pp_size + total_layers = self.model_config.get_total_num_hidden_layers() + assert ( + 0 <= nixl_agent_meta.start_layer < nixl_agent_meta.end_layer <= total_layers + ) + assert ( + len(nixl_agent_meta.kv_caches_base_addr) + == len(nixl_agent_meta.block_lens) + == len(nixl_agent_meta.registered_layer_indices) + == len(nixl_agent_meta.registered_layer_names) + ) + assert all( + nixl_agent_meta.start_layer <= global_layer_idx < nixl_agent_meta.end_layer + for global_layer_idx in nixl_agent_meta.registered_layer_indices + ) + assert nixl_agent_meta.registered_layer_indices == [ + 
extract_layer_index(name) for name in nixl_agent_meta.registered_layer_names + ] + # Will raise if any producer layer-name has no matching local region. + local_region_indices = self._local_region_indices_for_layer_names( + nixl_agent_meta.registered_layer_names + ) + + remote_info = self.transfer_topo.get_engine_info( + remote_engine_id, remote_pp_rank + ) assert remote_info.remote_tp_size == remote_tp_size tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size) @@ -1433,7 +1723,10 @@ def _validate_remote_agent_handshake( # MLA models do not need to handle kv replication. if not self.use_mla and not self._has_mamba: assert not ( - tp_ratio < 0 and self.transfer_topo.is_kv_replicated(remote_engine_id) + tp_ratio < 0 + and self.transfer_topo.is_kv_replicated( + remote_engine_id, remote_pp_rank + ) ) remote_physical_per_logical = ( @@ -1503,7 +1796,9 @@ def _validate_remote_agent_handshake( if ( abs(tp_ratio) != 1 and not self.use_mla - and not self.transfer_topo.is_kv_replicated(remote_engine_id) + and not self.transfer_topo.is_kv_replicated( + remote_engine_id, remote_pp_rank + ) and kv_cache_layout != "HND" and not self.enable_permute_local_kv ): @@ -1514,15 +1809,19 @@ def _validate_remote_agent_handshake( # Block len can only vary across layers when using MLA. remote_block_len = nixl_agent_meta.block_lens[0] - if self.use_mla or self.transfer_topo.is_kv_replicated(remote_engine_id): + remote_replicates_kv = self.transfer_topo.is_kv_replicated( + remote_engine_id, remote_pp_rank + ) + if self.use_mla or remote_replicates_kv: # With replicated KV cache, only the number of blocks can differ. # TODO (ZhanqiuHu): For mamba models, validate FA and mamba # block_lens separately. 
if not self._has_mamba: - for i in range(len(self.block_len_per_layer)): + for j, remote_block_len_j in enumerate(nixl_agent_meta.block_lens): + local_region_idx = local_region_indices[j] assert ( - self.block_len_per_layer[i] // block_size_ratio - == nixl_agent_meta.block_lens[i] + self.block_len_per_layer[local_region_idx] // block_size_ratio + == remote_block_len_j ), "KV cache sizes must match between P and D when replicated" else: # When MLA is not used, this is a list of the same block length @@ -1535,32 +1834,27 @@ def _validate_remote_agent_handshake( # max(attn_page, mamba_page), so the linear tp_ratio scaling # assumption only holds for pure-attention models. if not self._has_mamba: - if tp_ratio > 0: - assert ( - remote_block_len - == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio - ), ( - "Remote P worker KV layer cache must be of shape [2, N," - " local_kv_heads*tp_ratio, page_size, head_dim] and " - "same dtype." - ) - else: - assert block_size_ratio == 1, ( - "Different local/remote block sizes are not supported" - " when P TP > D TP." - ) - assert remote_block_len == self.block_len_per_layer[0] // ( - -tp_ratio - ), ( - "Remote P worker KV layer cache must be of shape [2, N," - " local_kv_heads/tp_ratio, page_size, head_dim] and " - "same dtype." + for j, remote_block_len_j in enumerate(nixl_agent_meta.block_lens): + local_region_idx = local_region_indices[j] + if tp_ratio > 0: + expected = ( + self.block_len_per_layer[local_region_idx] * tp_ratio + ) // block_size_ratio + else: + assert block_size_ratio == 1, ( + "Different local/remote block sizes are not supported" + " when P TP > D TP." + ) + expected = self.block_len_per_layer[local_region_idx] // ( + -tp_ratio + ) + assert remote_block_len_j == expected, ( + "Remote P worker KV layer cache shape is incompatible " + "with the local decode worker." ) # TP workers that handhshake with same remote have same #blocks. 
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks - # Same number of regions/~layers. - assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): """copy recved kv from host buffer to device.""" @@ -1743,7 +2037,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.sync_recved_kv_to_device(req_id, meta) # post processing for heteroblocksize - remote_info = self.transfer_topo.get_engine_info(meta.remote.engine_id) + remote_info = self.transfer_topo.get_engine_info(meta.remote.engine_id, 0) block_size_ratio = self.transfer_topo.block_size_ratio( remote_info.remote_block_size ) @@ -1953,10 +2247,10 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): ) # always store metadata for failure recovery self._recving_metadata[req_id] = meta - if remote_engine_id not in self._remote_agents: + if not self._handshake_complete(remote_engine_id, meta.pp_size): # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: - if remote_engine_id not in self._remote_agents: + if not self._handshake_complete(remote_engine_id, meta.pp_size): self._background_nixl_handshake(req_id, remote_engine_id, meta) continue @@ -2000,7 +2294,11 @@ def _send_heartbeats(self, metadata: NixlConnectorMetadata) -> None: # the **next** heartbeat for this remote can go through. 
if ( self._ensure_handshake( - engine_id, hb_info.host, hb_info.port, hb_info.tp_size + engine_id, + hb_info.host, + hb_info.port, + hb_info.tp_size, + hb_info.pp_size, ) is not None ): @@ -2021,87 +2319,109 @@ def _send_heartbeats(self, metadata: NixlConnectorMetadata) -> None: def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): assert meta.remote is not None and self.transfer_topo is not None engine_id = meta.remote.engine_id - plan = self.tp_mappings[engine_id] - remote_info = self.transfer_topo.get_engine_info(engine_id) - tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) + # Callers gate on _handshake_complete, which already populates _pp_layer_map. + pp_map = self._pp_layer_map[engine_id] - meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( - meta.remote.block_ids, - remote_info.remote_physical_blocks_per_logical, - ) - remote_block_ids = meta.remote.block_ids + # Expand to kernel block ids per remote PP stage (each may have a + # different physical_blocks_per_logical), not once with rank 0's. + # Keep meta.remote logical: it's reused on retry/failure paths. + logical_remote_block_ids = meta.remote.block_ids local_block_ids = meta.local_physical_block_ids - num_groups = len(local_block_ids) - read_specs = [ - ReadSpec( - remote_rank=rank, - local_block_ids=[ - list(local_block_ids[g]) - if rank in plan.source_ranks_per_group[g] - else [] - for g in range(num_groups) - ], - remote_block_ids=[ - list(remote_block_ids[g]) - if rank in plan.source_ranks_per_group[g] - else [] - for g in range(num_groups) - ], - ) - for rank in plan.all_source_ranks - ] - - # D may have to perform multiple reads from different remote ranks. - # MLA opt: when P TP > D TP, only a single read is executed for - # the first remote rank (cache is duplicated).. 
- if self.use_mla and tp_ratio < 0: - assert len(read_specs) == 1 - - for i, spec in enumerate(read_specs): - remote_block_size = remote_info.remote_block_size - logger.debug( - "Remote agent %s available, calling _read_blocks" - " on remote rank %s with remote block size %s for req %s", - meta.remote.engine_id, - spec.remote_rank, - remote_block_size, - req_id, + full_prefix_hit = len(local_block_ids) == 0 + + for remote_pp_rank in range(pp_map.pp_size): + plan = self.tp_mappings[(engine_id, remote_pp_rank)] + remote_info = self.transfer_topo.get_engine_info(engine_id, remote_pp_rank) + remote_block_ids = self._logical_to_remote_kernel_block_ids( + logical_remote_block_ids, + remote_info.remote_physical_blocks_per_logical, ) - # Get side handles. - if tp_ratio < 0 and not self.use_mla: - assert remote_block_size == self.block_size - # Remote tp_size > local tp_size: we must perform multiple - # reads. Get the memory chunk onto which we will write to. - local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i] - else: - # Single read from remote, we write to the whole memory region. - # Also handle remote block size different from local block size. - local_xfer_side_handle = self.src_xfer_handles_by_block_size[ - remote_block_size - ] - - # Destination handle: remote_engine_id -> remote_rank -> handle. 
- remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][ - spec.remote_rank + tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) + num_groups = len(local_block_ids) + read_specs = [ + ReadSpec( + remote_rank=rank, + local_block_ids=[ + list(local_block_ids[g]) + if rank in plan.source_ranks_per_group[g] + else [] + for g in range(num_groups) + ], + remote_block_ids=[ + list(remote_block_ids[g]) + if rank in plan.source_ranks_per_group[g] + else [] + for g in range(num_groups) + ], + ) + for rank in plan.all_source_ranks ] - self._read_blocks( - read_spec=spec, - request_id=req_id, - dst_engine_id=meta.remote.engine_id, - remote_request_id=meta.remote.request_id, - local_xfer_side_handle=local_xfer_side_handle, - remote_xfer_side_handle=remote_xfer_side_handle, - ) + # D may have to perform multiple reads from different remote ranks. + # MLA opt: when P TP > D TP, only a single read is executed for + # the first remote rank (cache is duplicated).. + if self.use_mla and tp_ratio < 0: + assert len(read_specs) == 1 - if self.use_mla and tp_ratio < 0 and read_specs: - # ..but we still need to notify the other remote ranks that we - # have the blocks we need so they can update the request state. - notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() - remote_agents = self._remote_agents[meta.remote.engine_id] - for rank_to_notify, agent in remote_agents.items(): - if rank_to_notify != read_specs[0].remote_rank: - self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) + for i, spec in enumerate(read_specs): + remote_block_size = remote_info.remote_block_size + logger.debug( + "Remote agent %s available, calling _read_blocks on remote " + "pp rank %s, tp rank %s with remote block size %s for req %s", + meta.remote.engine_id, + remote_pp_rank, + spec.remote_rank, + remote_block_size, + req_id, + ) + # Get side handles. 
+ if tp_ratio < 0 and not self.use_mla: + # Remote tp_size > local tp_size: we must perform multiple + # reads. Get the memory chunk onto which we will write to. + assert remote_block_size == self.block_size + split_key = (engine_id, remote_pp_rank, tp_ratio) + local_xfer_side_handle = self.src_xfer_handles_by_shard_tp_ratio[ + split_key + ][i] + else: + # Single read from remote, we write to the whole memory region. + # Also handle remote block size different from local block size. + local_xfer_side_handle = self.src_xfer_handles_by_remote[ + (engine_id, remote_pp_rank, remote_block_size) + ] + + # Destination handle: + # remote_engine_id -> (remote_pp_rank, remote_rank) -> handle. + remote_xfer_side_handle = self.dst_xfer_side_handles[ + meta.remote.engine_id + ][(remote_pp_rank, spec.remote_rank)] + + self._read_blocks( + read_spec=spec, + request_id=req_id, + dst_engine_id=meta.remote.engine_id, + remote_request_id=meta.remote.request_id, + remote_pp_rank=remote_pp_rank, + local_xfer_side_handle=local_xfer_side_handle, + remote_xfer_side_handle=remote_xfer_side_handle, + ) + + # ..but we still need to notify the other remote ranks that we + # have the blocks we need so they can update the request state. + if self.use_mla and tp_ratio < 0: + notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() + remote_agents = self._remote_agents[meta.remote.engine_id] + for (pp_rank, rank_to_notify), agent in remote_agents.items(): + if ( + pp_rank == remote_pp_rank + and rank_to_notify != spec.remote_rank + ): + self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) + + if full_prefix_hit: + # Notification-only path: the scheduler did not wait for remote KV, + # so there is no recv completion to report. 
+ self._recving_metadata.pop(req_id, None) def _read_blocks( self, @@ -2109,6 +2429,7 @@ def _read_blocks( dst_engine_id: str, request_id: str, remote_request_id: str, + remote_pp_rank: int, local_xfer_side_handle: int, remote_xfer_side_handle: int, ): @@ -2120,8 +2441,7 @@ def _read_blocks( remote_rank = read_spec.remote_rank local_block_ids = read_spec.local_block_ids remote_block_ids = read_spec.remote_block_ids - - remote_info = self.transfer_topo.get_engine_info(dst_engine_id) + remote_info = self.transfer_topo.get_engine_info(dst_engine_id, remote_pp_rank) block_size_ratio = self.transfer_topo.block_size_ratio( remote_info.remote_block_size ) @@ -2150,6 +2470,7 @@ def _read_blocks( ] local_block_ids = [local_block_ids_mapped] if local_block_ids_mapped else [] remote_block_ids = [remote_block_ids0] + # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the @@ -2167,7 +2488,9 @@ def _read_blocks( # just notify P worker that we have the blocks we need. if len(local_block_ids) == 0: # A full prefix cache hit is indicated with an empty list. - agent_name = self._remote_agents[dst_engine_id][remote_rank] + agent_name = self._remote_agents[dst_engine_id][ + (remote_pp_rank, remote_rank) + ] try: self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) except Exception as e: @@ -2178,6 +2501,7 @@ def _read_blocks( req_id=request_id, error=e, dst_engine_id=dst_engine_id, + remote_pp_rank=remote_pp_rank, remote_rank=remote_rank, remote_agent_name=agent_name, ) @@ -2199,17 +2523,17 @@ def _read_blocks( # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. 
- remote_block_descs_ids = self._compute_desc_ids( - block_ids=remote_block_ids, - dst_num_blocks=self.dst_num_blocks[dst_engine_id], - block_size_ratio=None, - physical_blocks_per_logical=remote_info.remote_physical_blocks_per_logical, + remote_block_descs_ids = self._get_block_descs_ids_for_shard( + dst_engine_id, + remote_pp_rank, + "remote", + remote_block_ids, ) - local_block_descs_ids = self._compute_desc_ids( - block_ids=local_block_ids, - dst_num_blocks=self.dst_num_blocks[self.engine_id], - block_size_ratio=block_size_ratio, - physical_blocks_per_logical=self._physical_blocks_per_logical_kv_block, + local_block_descs_ids = self._get_block_descs_ids_for_shard( + dst_engine_id, + remote_pp_rank, + "local", + local_block_ids, ) assert len(local_block_descs_ids) == len(remote_block_descs_ids) @@ -2239,10 +2563,32 @@ def _read_blocks( msg="Marking blocks as invalid", error=e, dst_engine_id=dst_engine_id, + remote_pp_rank=remote_pp_rank, remote_rank=remote_rank, ) self._handle_failed_transfer(request_id, handle) + def _get_block_descs_ids_for_shard( + self, + engine_id: str, + remote_pp_rank: int, + side: str, + block_ids: BlockIds, + ) -> np.ndarray: + """Get descriptor IDs relative to a shard-local prepared dlist.""" + num_blocks, region_group_ids = self._xfer_desc_layouts[ + (engine_id, remote_pp_rank, side) + ] + desc_ids = [] + for region_id, group_id in enumerate(region_group_ids): + group_arr = np.asarray(block_ids[group_id], dtype=np.int64) + if group_arr.size == 0: + continue + desc_ids.append(region_id * num_blocks + group_arr) + if not desc_ids: + return np.empty(0, dtype=np.int64) + return np.concatenate(desc_ids) + def get_mapped_blocks( self, block_ids: np.ndarray, block_size_ratio: int ) -> np.ndarray: @@ -2463,13 +2809,13 @@ def shutdown(self): for handle in handles: self.nixl_wrapper.release_xfer_handle(handle) self._recving_transfers.clear() - for handle in self.src_xfer_handles_by_block_size.values(): + for handle in 
self.src_xfer_handles_by_remote.values(): self.nixl_wrapper.release_dlist_handle(handle) - self.src_xfer_handles_by_block_size.clear() - for handles in self.src_xfer_handles_by_tp_ratio.values(): + self.src_xfer_handles_by_remote.clear() + for handles in self.src_xfer_handles_by_shard_tp_ratio.values(): for handle in handles: self.nixl_wrapper.release_dlist_handle(handle) - self.src_xfer_handles_by_tp_ratio.clear() + self.src_xfer_handles_by_shard_tp_ratio.clear() for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): for dst_xfer_side_handle in dst_xfer_side_handles.values(): self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) From 320017258b288aabd3adcf471bf99a9194acbfd2 Mon Sep 17 00:00:00 2001 From: zixi-qi Date: Thu, 28 May 2026 17:11:47 +0000 Subject: [PATCH 5/5] [KVConnector][NIXL] Per-layer-name HMA routing for hybrid models under PP Signed-off-by: zixi-qi --- .../test_consumer_shard_refactor.py | 3 + .../test_handshake_aggregation.py | 63 ++- .../test_hma_pp_per_layer_regions.py | 350 ++++++++++++++++ .../nixl_integration/test_pp_layer_map.py | 10 +- .../kv_connector/unit/test_nixl_connector.py | 18 + .../kv_connector/v1/nixl/metadata.py | 24 +- .../kv_connector/v1/nixl/worker.py | 389 ++++++++++++++---- 7 files changed, 753 insertions(+), 104 deletions(-) create mode 100644 tests/v1/kv_connector/nixl_integration/test_hma_pp_per_layer_regions.py diff --git a/tests/v1/kv_connector/nixl_integration/test_consumer_shard_refactor.py b/tests/v1/kv_connector/nixl_integration/test_consumer_shard_refactor.py index 50fc310c74f5..ef54d0c70bc1 100644 --- a/tests/v1/kv_connector/nixl_integration/test_consumer_shard_refactor.py +++ b/tests/v1/kv_connector/nixl_integration/test_consumer_shard_refactor.py @@ -91,6 +91,9 @@ def _make_worker(total_layers: int = 32) -> NixlConnectorWorker: worker.local_seen_layer_names = [ f"model.layers.{layer}.self_attn" for layer in range(total_layers) ] + worker._local_layer_name_to_region_indices = defaultdict(list) 
+ for idx, name in enumerate(worker.local_seen_layer_names): + worker._local_layer_name_to_region_indices[name].append(idx) worker._layer_name_to_kv_group_index = { layer_name: 0 for layer_name in worker.local_seen_layer_names } diff --git a/tests/v1/kv_connector/nixl_integration/test_handshake_aggregation.py b/tests/v1/kv_connector/nixl_integration/test_handshake_aggregation.py index 1e4380a66a37..f298d81158ad 100644 --- a/tests/v1/kv_connector/nixl_integration/test_handshake_aggregation.py +++ b/tests/v1/kv_connector/nixl_integration/test_handshake_aggregation.py @@ -24,6 +24,7 @@ NixlConnectorWorker, ReadSpec, ) +from vllm.v1.kv_cache_interface import FullAttentionSpec class _InlineThread: @@ -206,15 +207,8 @@ def test_background_nixl_handshake_submits_remote_pp_size(pp_size: int): assert future.add_done_callback.call_count == 2 -def test_hma_pp_assertion_guard_in_read_blocks() -> None: - """NIXL PR1 must reject HMA × PP combinations with AssertionError. - - This guard is the PR1↔PR2 split point. When NIXL PR2 lands per-layer-name - HMA × PP routing, the ``assert not self._is_hma_required`` checks inside - ``_read_blocks`` and friends will be lifted. Until then, configuring NIXL - with HMA enabled and a heterogeneous block-size remote (the path that - co-occurs under ``pp_size > 1`` with multi-group KV caches) must fail loud. 
- """ +def test_hma_pp_read_blocks_maps_each_kv_group() -> None: + """HMA reads map block-size ratios independently per KV group.""" import numpy as np worker = NixlConnectorWorker.__new__(NixlConnectorWorker) @@ -222,6 +216,14 @@ def test_hma_pp_assertion_guard_in_read_blocks() -> None: worker.world_size = 1 worker.block_size = 16 worker._remote_agents = {"remote-engine": {(0, 0): "agent-0-0"}} + worker._group_spec_types = (FullAttentionSpec, FullAttentionSpec) + worker.kv_cache_config = SimpleNamespace(kv_cache_groups=[object(), object()]) + worker._recving_transfers = {"req": []} + worker._log_failure = MagicMock() + worker._handle_failed_transfer = MagicMock() + worker.xfer_stats = MagicMock() + worker.nixl_wrapper = MagicMock() + worker.nixl_wrapper.make_prepped_xfer.return_value = 99 transfer_topo = MagicMock() transfer_topo.get_engine_info.return_value = SimpleNamespace( @@ -230,16 +232,35 @@ def test_hma_pp_assertion_guard_in_read_blocks() -> None: ) transfer_topo.block_size_ratio.return_value = 2 worker.transfer_topo = transfer_topo - worker.get_mapped_blocks = MagicMock(return_value=np.asarray([0, 1, 2, 3])) - - spec = ReadSpec(remote_rank=0, local_block_ids=[[0]], remote_block_ids=[[1]]) - with pytest.raises(AssertionError): - worker._read_blocks( - read_spec=spec, - request_id="req", - dst_engine_id="remote-engine", - remote_request_id="rreq", - remote_pp_rank=0, - local_xfer_side_handle=0, - remote_xfer_side_handle=0, + worker.get_mapped_blocks = NixlConnectorWorker.get_mapped_blocks.__get__(worker) + worker._apply_prefix_caching = MagicMock( + side_effect=lambda local, remote, _: ( + local, + remote, ) + ) + worker._get_block_descs_ids_for_shard = MagicMock( + side_effect=[np.asarray([10, 11, 12, 20, 21]), np.asarray([0, 1, 2, 6, 7])] + ) + + spec = ReadSpec( + remote_rank=0, + local_block_ids=[[0, 1], [3]], + remote_block_ids=[[5, 6, 7], [8, 9]], + ) + worker._read_blocks( + read_spec=spec, + request_id="req", + dst_engine_id="remote-engine", + 
remote_request_id="rreq", + remote_pp_rank=0, + local_xfer_side_handle=0, + remote_xfer_side_handle=0, + ) + + worker._apply_prefix_caching.assert_called_once_with( + [[0, 1, 2], [6, 7]], + [[5, 6, 7], [8, 9]], + 1, + ) + assert worker._recving_transfers["req"] == [99] diff --git a/tests/v1/kv_connector/nixl_integration/test_hma_pp_per_layer_regions.py b/tests/v1/kv_connector/nixl_integration/test_hma_pp_per_layer_regions.py new file mode 100644 index 000000000000..96817481a20d --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_hma_pp_per_layer_regions.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import defaultdict + +import msgspec + +from tests.v1.kv_connector.nixl_integration.test_consumer_shard_refactor import ( + REMOTE_ENGINE_ID, + _make_worker, +) +from tests.v1.kv_connector.nixl_integration.test_pp_layer_map import _meta +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( + NIXL_CONNECTOR_VERSION, + NixlAgentMetadata, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + _make_shard_desc_layout, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + +def _attn(layer_idx: int) -> str: + return f"model.layers.{layer_idx}.attn" + + +def _swa(layer_idx: int) -> str: + return f"model.layers.{layer_idx}.attn.swa_cache" + + +def _compressor(layer_idx: int) -> str: + return f"model.layers.{layer_idx}.attn.compressor.state_cache" + + +def _configure_hma_worker(layer_names: list[str], group_ids: list[int]): + worker = _make_worker(total_layers=128) + worker.local_seen_layer_names = layer_names + worker.block_len_per_layer = [1024] * len(layer_names) + worker.kv_caches_base_addr[worker.engine_id][worker._local_kv_cache_key] = [ + 100_000 + i * 10_000 for i in range(len(layer_names)) + ] + layer_name_to_region_indices: dict[str, list[int]] = defaultdict(list) + for idx, name in 
enumerate(layer_names): + layer_name_to_region_indices[name].append(idx) + worker._local_layer_name_to_region_indices = layer_name_to_region_indices + worker._layer_name_to_kv_group_index = dict(zip(layer_names, group_ids)) + worker._is_hma_required = True + return worker + + +def test_asymmetric_dsv4_pool_case_resolves_by_layer_name(): + local_layer_names = [ + _attn(15), + _compressor(13), + _swa(14), + _swa(15), + _attn(16), + _compressor(16), + _swa(16), + _swa(17), + _attn(18), + _compressor(18), + ] + worker = _configure_hma_worker(local_layer_names, [0] * len(local_layer_names)) + + producer_region_names = [ + _compressor(15), + _swa(15), + _attn(16), + _compressor(16), + _swa(16), + ] + # The old pool-subset matcher failed this shape because these names span + # two decode-side HMA pools. Per-layer regions need only exact names. + worker.local_seen_layer_names.insert(1, _compressor(15)) + worker.block_len_per_layer.insert(1, 1024) + worker.kv_caches_base_addr[worker.engine_id][worker._local_kv_cache_key].insert( + 1, 110_000 + ) + worker._layer_name_to_kv_group_index[_compressor(15)] = 0 + # Rebuild the layer-name → region index map so it reflects the insert. + worker._local_layer_name_to_region_indices = defaultdict( + list, + {name: [idx] for idx, name in enumerate(worker.local_seen_layer_names)}, + ) + + assert worker._local_region_indices_for_layer_names(producer_region_names) == [ + worker.local_seen_layer_names.index(name) for name in producer_region_names + ] + + +def test_pool_member_resolves_to_sharing_region_index(): + # Models the DeepseekV4 + PP failure: the local side pools (e.g.) L14's SWA + # cache with L16's main attention (HMA shared region), so L14's swa name is + # dedup'd out of ``local_seen_layer_names`` even though it lives in + # ``kv_caches``. The producer's PP slice ends right at L14 so it advertises + # ``model.layers.14.attn.swa_cache`` as a pool representative. 
The matcher + # must still route it to the local region that holds L14's SWA data. + representative_layer_names = [ + _attn(16), # local representative for the shared (c4a + swa) pool + _attn(18), # second shared-pool representative + ] + worker = _configure_hma_worker( + representative_layer_names, [0] * len(representative_layer_names) + ) + # L14's SWA and L16's main attn share an HMA region; the dedup keeps L16 + # in ``local_seen_layer_names`` but L14's swa is still part of the local + # kv_caches and must resolve to the same NIXL region. + worker._local_layer_name_to_region_indices[_swa(14)].append(0) + worker._local_layer_name_to_region_indices[_swa(16)].append(1) + worker._layer_name_to_kv_group_index[_swa(14)] = 1 + worker._layer_name_to_kv_group_index[_swa(16)] = 1 + + producer_layer_names = [ + _attn(16), + _swa(14), # producer's alone-SWA representative + _attn(18), + _swa(16), + ] + + assert worker._local_region_indices_for_layer_names(producer_layer_names) == [ + 0, + 0, + 1, + 1, + ] + + +def test_descriptor_ids_are_per_layer_and_kv_group_specific(): + layer_names = [_attn(15), _swa(15), _attn(16), _compressor(16)] + worker = _configure_hma_worker(layer_names, [0, 1, 0, 2]) + worker._xfer_desc_layouts[(REMOTE_ENGINE_ID, 1, "local")] = _make_shard_desc_layout( + num_blocks=10, + region_group_ids=(0, 1, 0, 2), + ) + + desc_ids = worker._get_block_descs_ids_for_shard( + REMOTE_ENGINE_ID, + 1, + "local", + ([1, 2], [7], [4, 5]), + ) + + assert desc_ids.tolist() == [1, 2, 17, 21, 22, 34, 35] + + +def test_expand_remote_members_routes_each_member_to_its_local_region(): + # Member-identity routing keyed by LAYER NAME: every producer member — + # incl. HMA cross-group pooled swa members — is resolved to the consumer + # region that physically holds it, independent of the producer's pooling. A + # producer region pooling attn(16) with swa(14)/swa(15) expands to three + # transfer units routed to whichever consumer regions hold them. 
+ worker = _configure_hma_worker([_attn(16), _attn(18)], [0, 0]) + # Consumer layout (swa is kv-group 1): region 0 holds attn(16)+swa(14)+ + # swa(15); region 1 holds attn(18)+swa(16)+swa(17). + worker._member_to_local_region = { + _attn(16): 0, + _swa(14): 0, + _swa(15): 0, + _attn(18): 1, + _swa(16): 1, + _swa(17): 1, + } + worker._layer_name_to_kv_group_index.update( + { + _attn(16): 0, + _attn(18): 0, + _swa(14): 1, + _swa(15): 1, + _swa(16): 1, + _swa(17): 1, + } + ) + meta = _meta( + 0, + 14, + 19, + pp_size=2, + registered_layer_indices=[16, 18], + registered_layer_names=[_attn(16), _attn(18)], + region_members=[ + [_attn(16), _swa(14), _swa(15)], + [_attn(18), _swa(16), _swa(17)], + ], + ) + + member_local_regions, member_groups, member_meta = worker._expand_remote_members( + meta + ) + + assert member_local_regions == [0, 0, 0, 1, 1, 1] + assert member_groups == (0, 1, 1, 0, 1, 1) + # Each producer region's base addr / block len is repeated per member so the + # region-based builders emit one descriptor group per member. 
+ r0, r1 = meta.kv_caches_base_addr[0], meta.kv_caches_base_addr[1] + assert member_meta.kv_caches_base_addr == [r0, r0, r0, r1, r1, r1] + assert len(member_meta.block_lens) == 6 + + +def test_expand_remote_members_raises_on_unknown_member(): + worker = _configure_hma_worker([_attn(16)], [0]) + worker._member_to_local_region = {_attn(16): 0} + meta = _meta( + 0, + 16, + 17, + pp_size=1, + registered_layer_indices=[16], + registered_layer_names=[_attn(16)], + region_members=[[_attn(16), _swa(14)]], # swa(14) has no local region + ) + + try: + worker._expand_remote_members(meta) + except RuntimeError as exc: + assert "no matching local region" in str(exc) + else: + raise AssertionError("expected RuntimeError for unmapped member") + + +def test_distinct_layer_names_in_same_kv_group_route_to_distinct_regions(): + # Regression: an MLA layer's main latent and its indexer/compressor cache + # can both land in kv-group 0 (UniformTypeKVCacheSpecs merges their specs), + # so a (layer_index, kv_group_index) member key is non-unique across + # regions. Member identity must key on the LAYER NAME so the two distinct + # caches route to their own consumer regions instead of collapsing onto one + # — collapsing double-writes one region and leaves the other's slots stale, + # which corrupts long-context KV under PP+HMA disaggregation. + worker = _configure_hma_worker([_attn(2), _compressor(2)], [0, 0]) + worker._member_to_local_region = {_attn(2): 0, _compressor(2): 1} + meta = _meta( + 0, + 2, + 3, + pp_size=1, + registered_layer_indices=[2, 2], + registered_layer_names=[_attn(2), _compressor(2)], + region_members=[[_attn(2)], [_compressor(2)]], + ) + + member_local_regions, member_groups, _ = worker._expand_remote_members(meta) + + # Distinct regions (0, 1) — NOT collapsed to [0, 0] as a (layer, group) key + # would. Both belong to kv-group 0. 
+ assert member_local_regions == [0, 1] + assert member_groups == (0, 0) + + +def test_mamba_descriptor_ids_use_mamba_suffix_and_group_filter(): + layer_names = [_attn(15), _compressor(16)] + worker = _configure_hma_worker(layer_names, [0, 1]) + worker._has_mamba = True + worker._group_spec_types = (FullAttentionSpec, MambaSpec) + worker._xfer_desc_layouts[(REMOTE_ENGINE_ID, 1, "local")] = _make_shard_desc_layout( + num_blocks=10, + region_group_ids=(0, 1), + physical_blocks_per_logical=2, + mamba_region_count=8, + mamba_region_group_ids=(0, 0, 0, 0, 1, 1, 1, 1), + ) + + desc_ids = worker._get_block_descs_ids_for_shard( + REMOTE_ENGINE_ID, + 1, + "local", + ([1, 2], [3]), + ) + + assert desc_ids.tolist() == [1, 2, 43, 48, 53, 58] + + +def test_repeated_layer_name_uses_matching_occurrence_for_split_regions(): + layer_name = "model.layers.3.self_attn" + worker = _configure_hma_worker([layer_name, layer_name], [0, 0]) + + assert worker._local_region_indices_for_layer_names([layer_name, layer_name]) == [ + 0, + 1, + ] + + +def test_nixl_agent_metadata_v6_registered_layer_names_round_trip(): + meta = _meta( + 0, + 0, + 4, + pp_size=1, + registered_layer_indices=[0, 2], + registered_layer_names=[ + "model.layers.0.self_attn", + "model.layers.2.indexer", + ], + ) + + decoded = msgspec.msgpack.decode( + msgspec.msgpack.encode(meta), type=NixlAgentMetadata + ) + + assert NIXL_CONNECTOR_VERSION == 6 + assert decoded.registered_layer_names == [ + "model.layers.0.self_attn", + "model.layers.2.indexer", + ] + + +def test_nixl_agent_metadata_v6_region_members_round_trip(): + # region_members advertises, per registered NIXL region, every layer name + # sharing that region — including HMA cross-group pooled members (e.g. an + # swa_cache pooled onto a later layer's main-attn region) that are dedup'd + # out of registered_layer_names. 
+ meta = _meta( + 0, + 0, + 4, + pp_size=1, + registered_layer_indices=[2, 4], + registered_layer_names=[_attn(2), _attn(4)], + region_members=[ + [_attn(2), _swa(0), _swa(1)], # region holds L2 attn + L0/L1 swa + [_attn(4), _swa(2), _swa(3)], # region holds L4 attn + L2/L3 swa + ], + ) + + decoded = msgspec.msgpack.decode( + msgspec.msgpack.encode(meta), type=NixlAgentMetadata + ) + + assert decoded.region_members == [ + [_attn(2), _swa(0), _swa(1)], + [_attn(4), _swa(2), _swa(3)], + ] + + +def test_shard_local_handler_uses_registered_layer_names(): + layer_names = [_attn(15), _swa(15), _attn(16)] + worker = _configure_hma_worker(layer_names, [0, 1, 0]) + worker.nixl_wrapper._next_handle = 0 + + _, blocks_data = worker.register_local_xfer_handler( + worker.block_size, + registered_layer_names=[_swa(15), _attn(16)], + ) + + assert worker._region_group_ids_for_layer_names([_swa(15), _attn(16)]) == ( + 1, + 1, + 0, + 0, + ) + assert [blocks_data[i][0] for i in (0, 8)] == [110_000, 120_000] diff --git a/tests/v1/kv_connector/nixl_integration/test_pp_layer_map.py b/tests/v1/kv_connector/nixl_integration/test_pp_layer_map.py index da874f3e7c41..83f2434b4a2b 100644 --- a/tests/v1/kv_connector/nixl_integration/test_pp_layer_map.py +++ b/tests/v1/kv_connector/nixl_integration/test_pp_layer_map.py @@ -24,6 +24,7 @@ def _meta( pp_size: int = 4, registered_layer_indices: list[int] | None = None, registered_layer_names: list[str] | None = None, + region_members: list[list[str]] | None = None, ) -> NixlAgentMetadata: if registered_layer_indices is None: registered_layer_indices = list(range(start_layer, end_layer)) @@ -31,6 +32,10 @@ def _meta( registered_layer_names = [ f"model.layers.{idx}.self_attn" for idx in registered_layer_indices ] + if region_members is None: + # Default: each advertised region holds exactly its representative + # layer name (the non-pooled / no-HMA shape). 
+ region_members = [[name] for name in registered_layer_names] return NixlAgentMetadata( engine_id="engine", agent_metadata=b"agent", @@ -49,6 +54,7 @@ def _meta( end_layer=end_layer, registered_layer_indices=registered_layer_indices, registered_layer_names=registered_layer_names, + region_members=region_members, ) @@ -94,8 +100,8 @@ def _fake_vllm_config(pipeline_parallel_size: int = 1) -> SimpleNamespace: ) -def test_nixl_connector_version_is_bumped_to_v5(): - assert NIXL_CONNECTOR_VERSION == 5 +def test_nixl_connector_version_is_bumped_to_v6(): + assert NIXL_CONNECTOR_VERSION == 6 def test_pp_layer_map_round_trip_queries(): diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index e3a62e2e6ea9..6f1083e928c3 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -533,6 +533,12 @@ def _nixl_handshake( self.dst_num_blocks[self.engine_id] = self.num_blocks self.local_registered_layer_indices = [0] self.local_seen_layer_names = ["model.layers.0.self_attn"] + # register_kv_caches() also builds the layer-name -> NIXL region map that + # the handshake validation resolves producer members against; mirror it + # here since this mock bypasses register_kv_caches(). + self._local_layer_name_to_region_indices = { + name: [i] for i, name in enumerate(self.local_seen_layer_names) + } self._local_kv_cache_key = (0, self.tp_rank) self.kv_caches_base_addr[self.engine_id][self._local_kv_cache_key] = [0] @@ -1021,6 +1027,12 @@ def test_handshake_fails_on_kv_cache_layout_mismatch( worker.dst_num_blocks[worker.engine_id] = worker.num_blocks worker.local_registered_layer_indices = [0] worker.local_seen_layer_names = ["model.layers.0.self_attn"] + # register_kv_caches() builds the layer-name -> NIXL region map the + # handshake validation resolves producer members against; mirror it + # here since this test sets local registration state by hand. 
+ worker._local_layer_name_to_region_indices = { + name: [i] for i, name in enumerate(worker.local_seen_layer_names) + } worker._local_kv_cache_key = (0, worker.tp_rank) worker.kv_caches_base_addr[worker.engine_id][worker._local_kv_cache_key] = [ 0 @@ -1091,6 +1103,12 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( worker.dst_num_blocks[worker.engine_id] = worker.num_blocks worker.local_registered_layer_indices = [0] worker.local_seen_layer_names = ["model.layers.0.self_attn"] + # register_kv_caches() builds the layer-name -> NIXL region map the + # handshake validation resolves producer members against; mirror it + # here since this test sets local registration state by hand. + worker._local_layer_name_to_region_indices = { + name: [i] for i, name in enumerate(worker.local_seen_layer_names) + } worker._local_kv_cache_key = (0, worker.tp_rank) worker.kv_caches_base_addr[worker.engine_id][worker._local_kv_cache_key] = [ 0 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py index f0bf59b54d6d..a7e99a55c102 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Metadata dataclasses and helpers for the NIXL connector.""" -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any from vllm.config import VllmConfig @@ -39,8 +39,15 @@ # registered_layer_names) and per-request pp_size. NIXL regions are # advertised per layer-name so HMA pool composition may differ across PP # producer shards and the decode consumer. +# 6: Add region_members so each advertised NIXL region declares ALL the +# layer names (members) sharing it — including HMA +# cross-group pooled members (e.g. 
an swa_cache pooled onto a later layer's +# main-attn region) that are dedup'd out of registered_layer_names. Without +# this, a pooled member belonging to a different kv group than the region's +# representative is never transferred (its blocks are dropped), corrupting +# KV under PP+HMA disaggregation. # -NIXL_CONNECTOR_VERSION: int = 5 +NIXL_CONNECTOR_VERSION: int = 6 @dataclass @@ -62,6 +69,19 @@ class NixlAgentMetadata: end_layer: int registered_layer_indices: list[int] registered_layer_names: list[str] + # Parallel to the advertised regions (registered_layer_names order): for + # each NIXL region, the full list of layer names whose transfer caches + # physically share that region. Captures HMA cross-group pooled members + # that registered_layer_names (representatives only) omits, so the transfer + # can cover every member's blocks in a shared region. Keyed by layer name + # (not (layer_index, kv_group_index)) because distinct caches can merge into + # one kv group via UniformTypeKVCacheSpecs (e.g. an MLA layer's main latent + # and its indexer k_cache both land in the full-attention group): a + # (layer_index, group) pair is then non-unique across regions, whereas the + # layer name uniquely identifies the region and is stable across the PP + # producer shard and the full-model consumer. Defaults to empty for + # backward-compatible construction; populated in register_kv_caches. 
+ region_members: list[list[str]] = field(default_factory=list) @dataclass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 082ac707c5c9..238dc192e865 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -11,7 +11,8 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor -from typing import TYPE_CHECKING, Any, cast +from dataclasses import replace +from typing import TYPE_CHECKING, Any, TypeAlias, cast import msgspec import numpy as np @@ -87,6 +88,31 @@ logger = init_logger(__name__) +_ShardDescLayout: TypeAlias = tuple[ + int, # num_blocks + tuple[int, ...], # region_group_ids (kv-group id per descriptor region) + int, # physical_blocks_per_logical + int, # mamba_region_count + tuple[int, ...], # mamba_region_group_ids +] + + +def _make_shard_desc_layout( + num_blocks: int, + region_group_ids: tuple[int, ...], + *, + physical_blocks_per_logical: int = 1, + mamba_region_count: int = 0, + mamba_region_group_ids: tuple[int, ...] = (), +) -> _ShardDescLayout: + return ( + num_blocks, + region_group_ids, + physical_blocks_per_logical, + mamba_region_count, + mamba_region_group_ids, + ) + class NixlConnectorWorker: """Implementation of Worker side methods""" @@ -381,6 +407,15 @@ def __init__( EngineId, dict[tuple[int, int], list[int]] ](dict) self.local_seen_layer_names: list[str] = [] + # Map every local layer name (including pool members that share a + # NIXL region with another layer) to the region indices its caches + # occupy. Needed because HMA pooling lets producer and consumer pick + # different "representative" layer names for the same shared region, + # so strict-name matching against ``local_seen_layer_names`` alone + # misses sharing-partner names. 
+ self._local_layer_name_to_region_indices: dict[str, list[int]] = defaultdict( + list + ) self._remote_agent_metadata: dict[ EngineId, dict[tuple[int, int], NixlAgentMetadata] ] = defaultdict(dict) @@ -403,11 +438,11 @@ def __init__( self.dst_xfer_side_handles = defaultdict[EngineId, dict[tuple[int, int], int]]( dict ) - # Per-shard descriptor layout: (num_blocks, region_group_ids) + # Per-shard descriptor layout: (num_blocks, region_group_ids, + # physical_blocks_per_logical, mamba_region_count, + # mamba_region_group_ids) # keyed by (engine_id, remote_pp_rank, "local" | "remote"). - self._xfer_desc_layouts: dict[ - tuple[EngineId, int, str], tuple[int, tuple[int, ...]] - ] = {} + self._xfer_desc_layouts: dict[tuple[EngineId, int, str], _ShardDescLayout] = {} # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. @@ -513,22 +548,27 @@ def _get_local_base_addresses(self) -> list[int]: def _local_region_indices_for_layer_names( self, registered_layer_names: list[str] ) -> list[int]: - local_names = self.local_seen_layer_names - positions_by_name: dict[str, list[int]] = defaultdict(list) - for local_idx, layer_name in enumerate(local_names): - positions_by_name[layer_name].append(local_idx) - + # ``_local_layer_name_to_region_indices`` covers every layer name + # present in the local kv_caches mapping — including pool members that + # were dedup'd out of ``local_seen_layer_names`` because they share a + # NIXL region with another (representative) layer. Strict matching on + # ``local_seen_layer_names`` would miss those names whenever the + # producer and consumer pick different pool representatives (e.g. + # producer's PP-rank-0 alone-SWA representative vs the full-model + # alone-SWA representative on the consumer). 
+ mapping = self._local_layer_name_to_region_indices occurrences_by_name: dict[str, int] = defaultdict(int) local_indices: list[int] = [] for layer_name in registered_layer_names: occurrence = occurrences_by_name[layer_name] occurrences_by_name[layer_name] += 1 - matches = positions_by_name.get(layer_name, []) + matches = mapping.get(layer_name, []) if occurrence >= len(matches): raise RuntimeError( "NIXL handshake failed: producer registered layer " f"{layer_name!r} occurrence {occurrence} has no matching " - f"local region. Local registered layers: {local_names}" + f"local region. Local registered layers: " + f"{self.local_seen_layer_names}" ) local_indices.append(matches[occurrence]) return local_indices @@ -538,14 +578,78 @@ def _region_group_ids_for_layer_names( ) -> tuple[int, ...]: # Non-MLA split-K/V backends register K and V as separate regions # that share the same KV-group id, hence the x2 duplication. - group_ids = tuple( - self._layer_name_to_kv_group_index[name] for name in registered_layer_names - ) + group_ids = self._kv_group_indices_for_layer_names(registered_layer_names) assert self.transfer_topo is not None if self.transfer_topo.is_kv_layout_blocks_first: return tuple(g for g in group_ids for _ in range(2)) return group_ids + def _use_member_identity(self, nixl_agent_meta: NixlAgentMetadata) -> bool: + # Member-identity routing (B6): resolve each producer member to the + # consumer region that holds it, robust to HMA pool-representative + # divergence under PP. Applies on the plain non-blocks-first path + # (regions map 1:1 to the prepared dlist) for v6 producers. Blocks-first + # (virtual K/V split) and mamba (x4 expansion) keep the legacy path. 
+ assert self.transfer_topo is not None + return ( + not self.transfer_topo.is_kv_layout_blocks_first + and not self._has_mamba + and bool(nixl_agent_meta.region_members) + ) + + def _expand_remote_members( + self, nixl_agent_meta: NixlAgentMetadata + ) -> tuple[list[int], tuple[int, ...], NixlAgentMetadata]: + # Expand a producer's region_members into one transfer unit per member. + # Returns (member_local_regions, member_groups, member_meta): + # - member_local_regions[k]: consumer NIXL region holding the k-th + # producer member, resolved by layer name. + # - member_groups[k]: that member's kv-group id (consumer side; matches + # the producer group since group ordering is validated at handshake). + # - member_meta: nixl_agent_meta with kv_caches_base_addr/block_lens + # expanded to one entry per member (repeating the producer region the + # member lives in), so the region-based builders emit one descriptor + # group per member without further changes. + member_local_regions: list[int] = [] + member_groups: list[int] = [] + member_remote_base: list[int] = [] + member_block_lens: list[int] = [] + for r, members in enumerate(nixl_agent_meta.region_members): + for layer_name in members: + local_region = self._member_to_local_region.get(layer_name) + if local_region is None: + raise RuntimeError( + "NIXL handshake failed: producer member " + f"{layer_name!r} has no matching local region." 
+ ) + member_local_regions.append(local_region) + member_groups.append(self._layer_name_to_kv_group_index[layer_name]) + member_remote_base.append(nixl_agent_meta.kv_caches_base_addr[r]) + member_block_lens.append(nixl_agent_meta.block_lens[r]) + member_meta = replace( + nixl_agent_meta, + kv_caches_base_addr=member_remote_base, + block_lens=member_block_lens, + ) + return member_local_regions, tuple(member_groups), member_meta + + def _mamba_region_group_ids_for_layer_names( + self, registered_layer_names: list[str] + ) -> tuple[int, ...]: + group_ids = self._kv_group_indices_for_layer_names(registered_layer_names) + return tuple(g for g in group_ids for _ in range(4)) + + def _kv_group_indices_for_layer_names( + self, registered_layer_names: list[str] + ) -> tuple[int, ...]: + mapping = self._layer_name_to_kv_group_index + try: + return tuple(mapping[name] for name in registered_layer_names) + except KeyError as exc: + raise RuntimeError( + f"KV cache layer {exc.args[0]!r} is not present in any kv_cache_group." + ) from exc + def _try_update_pp_layer_map( self, engine_id: EngineId, remote_pp_size: int ) -> PPLayerMap | None: @@ -953,15 +1057,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): "cross-layer-blocks mode is not supported with " "pipeline_parallel_size > 1 yet." ) - if self._has_mamba and pp_size > 1: - # Per-shard descriptor layouts for hybrid (Mamba/SSM) producers - # need mamba_region_count / mamba_region_group_ids tracking that - # the consumer descriptor builder does not yet implement. Follow-up - # PR adds this. - raise RuntimeError( - "Hybrid (Mamba/SSM) models are not yet supported with " - "pipeline_parallel_size > 1 over NIXL PD disaggregation." 
- ) pp_rank = get_pp_group().rank_in_group start_layer, end_layer = self.model_config.get_layers_start_end_indices( self.vllm_config.parallel_config @@ -1008,6 +1103,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_len_per_layer = list[int]() seen_layer_indices: list[int] = [] seen_layer_names: list[str] = [] + # Reset before populating: register_kv_caches may run again after a + # sleep/wake cycle, and stale mappings would point at outdated regions. + self._local_layer_name_to_region_indices = defaultdict(list) + base_addr_to_region_idx: dict[int, int] = {} + # Parallel to seen_base_addresses: for each NIXL region, the full list + # of layer names whose transfer caches physically share it, including + # HMA cross-group pooled members dedup'd out of seen_layer_names. Drives + # per-member (all-group) transfer coverage. + region_members: list[list[str]] = [] for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. @@ -1052,16 +1156,23 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # registering a single tensor for both K/V and splitting logically like FI. for cache in cache_list: base_addr = cache.data_ptr() - if base_addr in seen_base_addresses: + existing_region_idx = base_addr_to_region_idx.get(base_addr) + if existing_region_idx is not None: # NOTE (NickLucche) HMA employs memory pooling to share tensors # across groups. This results in skipping all tensors but the ones # pointed to by group0. Also, generally we will have more blocks # per tensor but fewer regions. 
logger.debug("Skipping %s because it's already seen", layer_name) + self._local_layer_name_to_region_indices[layer_name].append( + existing_region_idx + ) + region_members[existing_region_idx].append(layer_name) continue logger.debug( "Registering layer %s with cache shape: %s", layer_name, cache.shape ) + region_idx = len(seen_base_addresses) + base_addr_to_region_idx[base_addr] = region_idx seen_base_addresses.append(base_addr) # Only record non-Mamba page sizes. if isinstance(layer_spec, MambaSpec): @@ -1072,6 +1183,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_len_per_layer.append(physical_page_size) seen_layer_indices.append(layer_index) seen_layer_names.append(layer_name) + self._local_layer_name_to_region_indices[layer_name].append(region_idx) + region_members.append([layer_name]) if cache.shape[0] != num_blocks: raise AssertionError( @@ -1109,6 +1222,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert all(start_layer <= idx < end_layer for idx in seen_layer_indices) self._local_kv_cache_key = (pp_rank, self.tp_rank) self.local_seen_layer_names = seen_layer_names + self.region_members = region_members + # layer_name -> local NIXL region index. Each transfer-cache layer name + # maps to exactly one NIXL region, so this lets a consumer resolve each + # producer member to the region that physically holds it, independent of + # how the producer pooled it (member-identity routing — robust to HMA + # pool-representative divergence under PP). Keyed by layer name rather + # than (layer_index, kv_group_index): distinct caches can merge into one + # kv group via UniformTypeKVCacheSpecs (e.g. an MLA layer's main latent + # and its indexer k_cache both join the full-attention group), so the + # (layer, group) pair is non-unique across regions and would otherwise + # collapse those members onto one region (double-write + stale slot). 
+ self._member_to_local_region: dict[str, int] = {} + for region_idx, members in enumerate(region_members): + for member in members: + self._member_to_local_region.setdefault(member, region_idx) self.kv_caches_base_addr[self.engine_id][self._local_kv_cache_key] = ( seen_base_addresses ) @@ -1186,6 +1314,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): end_layer=end_layer, registered_layer_indices=seen_layer_indices, registered_layer_names=seen_layer_names, + region_members=region_members, ) # Wrap metadata in payload with hash for defensive decoding assert self.compat_hash is not None @@ -1199,6 +1328,7 @@ def _build_mamba_local( self, base_addresses: list[int], block_size_ratio: int, + local_region_indices: list[int] | None = None, ) -> list[tuple[int, int, int]]: """Build 4 desc regions (x, B, C, ssm) per layer for local mamba blocks, enabling the 3-read transfer with DS conv layout.""" @@ -1213,7 +1343,9 @@ def _build_mamba_local( physical_per_logical = self._physical_blocks_per_logical_kv_block result: list[tuple[int, int, int]] = [] - for i, base_addr in enumerate(base_addresses): + if local_region_indices is None: + local_region_indices = list(range(len(base_addresses))) + for i, base_addr in zip(local_region_indices, base_addresses): # Jump one page_size, but ssm page_size may be bigger when kernel # locks block size to a specific value (physical_per_logical scale). 
page_stride = ( @@ -1373,6 +1505,7 @@ def register_local_xfer_handler( block_size: int, *, registered_layer_names: list[str] | None = None, + local_region_indices: list[int] | None = None, ) -> tuple[int, list[tuple[int, int, int]]]: """ Function used for register local xfer handler with local block_size or @@ -1387,10 +1520,17 @@ def register_local_xfer_handler( """ assert self.transfer_topo is not None block_size_ratio = self.block_size // block_size - # PP-aware: when registered_layer_names is provided, register only the - # local regions matching the producer shard's layers (in the producer's - # advertised order). Otherwise register all local regions. - if registered_layer_names is None: + # PP-aware region selection (in the producer's advertised order): + # - local_region_indices given (member-identity, B6): one entry per + # producer member, already resolved to the consumer region holding it. + # - registered_layer_names given: one entry per producer region + # (representative). Otherwise register all local regions. + if local_region_indices is not None: + local_base_addresses_all = self._get_local_base_addresses() + local_base_addresses = [ + local_base_addresses_all[i] for i in local_region_indices + ] + elif registered_layer_names is None: local_base_addresses = self._get_local_base_addresses() local_region_indices = list(range(len(local_base_addresses))) else: @@ -1413,7 +1553,8 @@ def register_local_xfer_handler( self.device_id, ) if self._has_mamba: - assert self.num_descs == len(blocks_data) + if registered_layer_names is None: + assert self.num_descs == len(blocks_data) # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-descs split # is unnecessary — a single conv desc per block suffices. Consider # adding a fast path that falls back to the standard 2-region @@ -1422,7 +1563,9 @@ def register_local_xfer_handler( # because local descs are created before knowing the remote TP. 
logger.debug("Registering local Mamba descriptors (4 regions/layer)") blocks_data.extend( - self._build_mamba_local(local_base_addresses, block_size_ratio) + self._build_mamba_local( + local_base_addresses, block_size_ratio, local_region_indices + ) ) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) @@ -1570,6 +1713,20 @@ def add_remote_agent( plan = self.tp_mappings[plan_key] + # Member-identity expansion (B6): one transfer unit per producer member, + # each routed by layer name to the consumer region holding it. Covers + # HMA cross-group pooled members (e.g. swa) regardless of which side's + # pool representative they were dedup'd under. member_meta falls back to + # nixl_agent_meta unchanged when member-identity does not apply. + member_identity = self._use_member_identity(nixl_agent_meta) + member_local_regions: list[int] = [] + member_groups: tuple[int, ...] = () + member_meta = nixl_agent_meta + if member_identity: + member_local_regions, member_groups, member_meta = ( + self._expand_remote_members(nixl_agent_meta) + ) + # PP-aware: lazily register a local xfer handler for this producer # shard's layers + block size (idempotent across (engine, pp_rank, # block_size)). 
Replaces main's eager call in register_kv_caches plus @@ -1579,21 +1736,43 @@ def add_remote_agent( handle, blocks_data = self.register_local_xfer_handler( nixl_agent_meta.block_size, registered_layer_names=nixl_agent_meta.registered_layer_names, + local_region_indices=member_local_regions if member_identity else None, ) self.src_xfer_handles_by_remote[local_handle_key] = handle self.src_blocks_data_by_remote[local_handle_key] = blocks_data self._xfer_desc_layouts[(engine_id, remote_pp_rank, "local")] = ( - self.num_blocks * block_size_ratio, - self._region_group_ids_for_layer_names( - nixl_agent_meta.registered_layer_names - ), + _make_shard_desc_layout( + self.num_blocks * block_size_ratio, + member_groups + if member_identity + else self._region_group_ids_for_layer_names( + nixl_agent_meta.registered_layer_names + ), + physical_blocks_per_logical=( + self._physical_blocks_per_logical_kv_block + ), + mamba_region_count=len(nixl_agent_meta.registered_layer_names) * 4 + if self._has_mamba + else 0, + mamba_region_group_ids=( + self._mamba_region_group_ids_for_layer_names( + nixl_agent_meta.registered_layer_names + ) + if self._has_mamba + else () + ), + ) ) src_blocks_data = self.src_blocks_data_by_remote[local_handle_key] - local_num_blocks, local_region_group_ids = self._xfer_desc_layouts[ + local_num_blocks, local_region_group_ids, _, _, _ = self._xfer_desc_layouts[ (engine_id, remote_pp_rank, "local") ] - local_region_indices = self._local_region_indices_for_layer_names( - nixl_agent_meta.registered_layer_names + local_region_indices = ( + member_local_regions + if member_identity + else self._local_region_indices_for_layer_names( + nixl_agent_meta.registered_layer_names + ) ) ### (Optional) Register local agent memory regions. MLA is not split. @@ -1628,7 +1807,7 @@ def add_remote_agent( # Register all remote blocks, but only the corresponding kv heads. 
blocks_data = self._build_fa_remote( plan, - nixl_agent_meta, + member_meta, block_size_ratio, local_region_indices, ) @@ -1662,10 +1841,25 @@ def add_remote_agent( ) self._xfer_desc_layouts[(engine_id, remote_pp_rank, "remote")] = ( - nixl_agent_meta.num_blocks, - self._region_group_ids_for_layer_names( - nixl_agent_meta.registered_layer_names - ), + _make_shard_desc_layout( + nixl_agent_meta.num_blocks, + member_groups + if member_identity + else self._region_group_ids_for_layer_names( + nixl_agent_meta.registered_layer_names + ), + physical_blocks_per_logical=physical_blocks_per_logical, + mamba_region_count=len(nixl_agent_meta.kv_caches_base_addr) * 4 + if self._has_mamba + else 0, + mamba_region_group_ids=( + self._mamba_region_group_ids_for_layer_names( + nixl_agent_meta.registered_layer_names + ) + if self._has_mamba + else () + ), + ) ) self._remote_agents[engine_id][shard_key] = remote_agent_name @@ -1764,9 +1958,10 @@ def _validate_remote_agent_handshake( "Remote is HND and local is NHD, enabled additional permute " "on local device KV." ) - assert not self._is_hma_required, ( - "HMA does not support block size post processing" - ) + if self._is_hma_required: + raise RuntimeError( + "HMA does not support block size post processing" + ) self.enable_permute_local_kv = True else: raise RuntimeError( @@ -2044,7 +2239,10 @@ def get_finished(self) -> tuple[set[str], set[str]]: if not self.use_mla and ( block_size_ratio > 1 or self.enable_permute_local_kv ): - assert not self._is_hma_required + if self._is_hma_required: + raise RuntimeError( + "HMA does not support block size post processing" + ) block_ids_for_blocksize_post_process[block_size_ratio].append( meta.local_physical_block_ids[0] ) @@ -2446,30 +2644,28 @@ def _read_blocks( remote_info.remote_block_size ) if block_size_ratio > 1: - # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups. 
- assert not self._is_hma_required - local_block_ids0 = local_block_ids[0] if local_block_ids else [] - remote_block_ids0 = remote_block_ids[0] - local_block_ids_mapped = self.get_mapped_blocks( - np.asarray(local_block_ids0), block_size_ratio - ).tolist() - if len(local_block_ids_mapped) > len(remote_block_ids0): - # NOTE: - # get_mapped_blocks will always expand block_ids for n times. - # ex: - # prefill block_ids with block_size as 4: - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - # Local decode block_ids with block_size as 16: [1, 2, 3] - # expanded decode block_ids with get_mapped_blocks from [1, 2, 3] to - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] - # Then we clip local to align with prefill - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - local_block_ids_mapped = local_block_ids_mapped[ - : len(remote_block_ids0) - ] - local_block_ids = [local_block_ids_mapped] if local_block_ids_mapped else [] - remote_block_ids = [remote_block_ids0] + local_block_ids_mapped_by_group: list[list[int]] = [] + remote_block_ids_by_group: list[list[int]] = [] + for group_idx, remote_group_ids in enumerate(remote_block_ids): + remote_group_ids = list(remote_group_ids) + local_group_ids = list(local_block_ids[group_idx]) + if _is_ssm_spec(self._group_spec_types[group_idx]): + local_block_ids_mapped = local_group_ids + else: + local_block_ids_mapped = self.get_mapped_blocks( + np.asarray(local_group_ids), block_size_ratio + ).tolist() + if len(local_block_ids_mapped) > len(remote_group_ids): + # get_mapped_blocks expands each local block by + # block_size_ratio. Clip padding to the producer's + # actual remote block count. 
+ local_block_ids_mapped = local_block_ids_mapped[ + : len(remote_group_ids) + ] + local_block_ids_mapped_by_group.append(local_block_ids_mapped) + remote_block_ids_by_group.append(remote_group_ids) + local_block_ids = local_block_ids_mapped_by_group + remote_block_ids = remote_block_ids_by_group # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). @@ -2576,18 +2772,53 @@ def _get_block_descs_ids_for_shard( block_ids: BlockIds, ) -> np.ndarray: """Get descriptor IDs relative to a shard-local prepared dlist.""" - num_blocks, region_group_ids = self._xfer_desc_layouts[ - (engine_id, remote_pp_rank, side) - ] + ( + num_blocks, + region_group_ids, + physical_blocks_per_logical, + mamba_region_count, + mamba_region_group_ids, + ) = self._xfer_desc_layouts[(engine_id, remote_pp_rank, side)] + desc_ids = [] for region_id, group_id in enumerate(region_group_ids): - group_arr = np.asarray(block_ids[group_id], dtype=np.int64) - if group_arr.size == 0: + if self._has_mamba and _is_ssm_spec(self._group_spec_types[group_id]): continue - desc_ids.append(region_id * num_blocks + group_arr) + group_arr = np.asarray(block_ids[group_id], dtype=np.int64) + if group_arr.size > 0: + desc_ids.append(region_id * num_blocks + group_arr) + + if self._has_mamba: + assert physical_blocks_per_logical > 0 + assert num_blocks % physical_blocks_per_logical == 0, ( + "Mamba descriptor layout num_blocks must be divisible by " + "physical_blocks_per_logical" + ) + assert mamba_region_count > 0 + assert len(mamba_region_group_ids) == mamba_region_count + logical_blocks = num_blocks // physical_blocks_per_logical + num_fa_descs = len(region_group_ids) * num_blocks + for region_id, group_id in enumerate(mamba_region_group_ids): + if not _is_ssm_spec(self._group_spec_types[group_id]): + continue + group_arr = np.asarray(block_ids[group_id], dtype=np.int64) + if group_arr.size > 0: + desc_ids.append( + num_fa_descs + 
region_id * logical_blocks + group_arr + ) + if not desc_ids: return np.empty(0, dtype=np.int64) - return np.concatenate(desc_ids) + desc_ids_arr = np.concatenate(desc_ids) + if self._has_mamba: + num_descs = len(region_group_ids) * num_blocks + mamba_region_count * ( + num_blocks // physical_blocks_per_logical + ) + assert int(desc_ids_arr.max()) < num_descs, ( + "Mamba shard descriptor IDs must be relative to the shard-local " + "dlist layout" + ) + return desc_ids_arr def get_mapped_blocks( self, block_ids: np.ndarray, block_size_ratio: int