diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 6d4e6565e373..c388419c6846 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -141,9 +141,6 @@ def test_read_blocks_for_req_expands_remote_ids( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import ( NixlConnectorMetadata, ) - from vllm.distributed.kv_transfer.kv_connector.v1.nixl.tp_mapping import ( - TPMapping, - ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( NixlConnectorWorker, ) @@ -172,7 +169,7 @@ def test_read_blocks_for_req_expands_remote_ids( remote_engine_id = "remote-engine" worker.transfer_topo = MagicMock() - # tp_ratio not exercised (all_source_ranks is empty so no reads run), + # tp_ratio not exercised (remote_ranks is empty so no reads run), # but set for realism. worker.transfer_topo.tp_ratio.return_value = tp_ratio remote_info = MagicMock() @@ -180,10 +177,10 @@ def test_read_blocks_for_req_expands_remote_ids( worker.transfer_topo.get_engine_info.return_value = remote_info worker.use_mla = False - mock_plan = MagicMock(spec=TPMapping) - mock_plan.all_source_ranks = () - mock_plan.source_ranks_per_group = () - worker.tp_mappings = {remote_engine_id: mock_plan} + # Empty tp_mappings: no source ranks so no reads are issued. + num_groups = len(resolved_types) + worker.tp_mappings = {remote_engine_id: tuple({} for _ in range(num_groups))} + worker.remote_ranks = {remote_engine_id: ()} metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( @@ -346,9 +343,6 @@ def test_mismatched_physical_per_logical_fails_with_prefix_caching( mamba_enabled=True, ) worker._has_mamba = True - worker._group_spec_types = tuple( - type(g.kv_cache_spec) for g in worker.kv_cache_config.kv_cache_groups - ) local_block_ids = (local_fa_blocks, ssm_blocks) remote_block_ids = (remote_fa_blocks, ssm_blocks) diff --git a/tests/v1/kv_connector/unit/test_tp_mapping.py b/tests/v1/kv_connector/unit/test_tp_mapping.py index 95d49faf042f..08e0cd89b82a 100644 --- a/tests/v1/kv_connector/unit/test_tp_mapping.py +++ b/tests/v1/kv_connector/unit/test_tp_mapping.py @@ -9,62 +9,99 @@ from __future__ import annotations -from types import SimpleNamespace +from unittest.mock import MagicMock import pytest +import torch -from vllm.distributed.kv_transfer.kv_connector.v1.nixl.tp_mapping import ( - TPMapping, - compute_tp_mapping, -) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( NixlConnectorWorker, ) -from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + MambaSpec, + ShardRange, + TPTransferSlice, +) # ====================================================================== # Test fixtures / helpers # ====================================================================== -def _compute_mapping( +def _make_fa_spec(num_kv_heads: int = 4): + return FullAttentionSpec( + block_size=16, + num_kv_heads=num_kv_heads, + head_size=128, + head_size_v=128, + dtype=torch.float16, + ) + + +def _get_slices( tp_rank: int = 0, tp_size: int = 1, remote_tp_size: int = 1, - is_mla: bool = False, - num_kv_heads: int = 8, - group_spec_types: tuple[type, ...] = (FullAttentionSpec,), -) -> TPMapping: - transfer_topology = SimpleNamespace( - tp_rank=tp_rank, - tp_size=tp_size, - is_mla=is_mla, - total_num_kv_heads=num_kv_heads, - ) - return compute_tp_mapping( - transfer_topology=transfer_topology, - remote_tp_size=remote_tp_size, - group_spec_types=group_spec_types, + total_num_kv_heads: int = 8, + spec=None, +) -> dict[int, TPTransferSlice]: + """Call get_tp_transfer_slices on the given spec (or a default FA spec).""" + if spec is None: + num_kv_heads = max(1, total_num_kv_heads // tp_size) + spec = _make_fa_spec(num_kv_heads) + return spec.get_tp_transfer_slices( + tp_rank, tp_size, remote_tp_size, total_num_kv_heads ) +def _remote_ranks_from_slices( + *group_slices: dict[int, TPTransferSlice], +) -> tuple[int, ...]: + """Derive deduplicated sorted source ranks from multiple group slices.""" + return tuple(sorted({r for slices in group_slices for r in slices})) + + # ====================================================================== # TP mapping structure tests # ====================================================================== class TestTPMappingStructure: - def test_source_ranks_homogeneous(self): - m = _compute_mapping(tp_size=2, tp_rank=1, remote_tp_size=2) - assert m.all_source_ranks == (1,) + def test_remote_ranks_homogeneous(self): + slices = _get_slices(tp_size=2, tp_rank=1, remote_tp_size=2) + assert _remote_ranks_from_slices(slices) == (1,) - def test_source_ranks_d_gt_p(self): - m = _compute_mapping(tp_size=4, tp_rank=2, remote_tp_size=2) - assert m.all_source_ranks == (1,) + def test_remote_ranks_d_gt_p(self): + slices = _get_slices(tp_size=4, tp_rank=2, remote_tp_size=2) + assert _remote_ranks_from_slices(slices) == (1,) - def test_source_ranks_p_gt_d(self): - m = _compute_mapping(tp_size=1, tp_rank=0, remote_tp_size=2) - assert m.all_source_ranks == (0, 1) + def test_remote_ranks_p_gt_d(self): + slices = _get_slices(tp_size=1, tp_rank=0, remote_tp_size=2) + assert _remote_ranks_from_slices(slices) == (0, 1) + + def test_per_group_slices(self): + slices = _get_slices(tp_size=2, tp_rank=0, remote_tp_size=4) + assert len(slices) == 2 + assert 0 in slices + assert 1 in slices + + def test_has_rank_in_group(self): + slices = _get_slices(tp_size=1, tp_rank=0, remote_tp_size=2) + assert 0 in slices + assert 1 in slices + assert 2 not in slices + + def test_gqa_dedup_load_balanced(self): + """With total_heads=2, remote_tp=4: picks aligned remote ranks.""" + slices_r0 = _get_slices( + tp_size=2, tp_rank=0, remote_tp_size=4, total_num_kv_heads=2 + ) + slices_r1 = _get_slices( + tp_size=2, tp_rank=1, remote_tp_size=4, total_num_kv_heads=2 + ) + assert 0 in slices_r0 + assert 2 in slices_r1 # ====================================================================== @@ -72,34 +109,53 @@ def test_source_ranks_p_gt_d(self): # ====================================================================== -def _make_mock_worker_for_splits(group_spec_types): - """Build a mock NixlConnectorWorker with _group_spec_types for split tests.""" +def _make_mock_worker_for_splits( + group_specs: list, + tp_mappings: tuple, + remote_ranks: tuple[int, ...], + engine_id: str = "remote_0", +): + """Build a mock NixlConnectorWorker with the fields _build_local_splits needs.""" worker = object.__new__(NixlConnectorWorker) - worker._group_spec_types = group_spec_types + kv_cache_groups = [] + for spec in group_specs: + group = MagicMock() + group.kv_cache_spec = spec + kv_cache_groups.append(group) + kv_cache_config = MagicMock() + kv_cache_config.kv_cache_groups = kv_cache_groups + worker.kv_cache_config = kv_cache_config + worker.tp_mappings = {engine_id: tp_mappings} + worker.remote_ranks = {engine_id: remote_ranks} + worker.transfer_topo = MagicMock() return worker class TestBuildSrcSplitHandles: @pytest.mark.parametrize("remote_tp_size", [2, 4]) - def test_build_src_split_handles(self, remote_tp_size): + def test_split_shape(self, remote_tp_size): + """Each split has correct number of descs with correct chunk size.""" tp_rank = 0 tp_size = 1 + total_num_kv_heads = 8 + engine_id = "remote_0" - plan = _compute_mapping( - tp_rank=tp_rank, - tp_size=tp_size, - remote_tp_size=remote_tp_size, + fa_spec = _make_fa_spec(num_kv_heads=total_num_kv_heads // tp_size) + fa_slices = fa_spec.get_tp_transfer_slices( + tp_rank, tp_size, remote_tp_size, total_num_kv_heads ) + remote_ranks = _remote_ranks_from_slices(fa_slices) - worker = _make_mock_worker_for_splits((FullAttentionSpec,)) + worker = _make_mock_worker_for_splits( + group_specs=[fa_spec], + tp_mappings=(fa_slices,), + remote_ranks=remote_ranks, + engine_id=engine_id, + ) src_blocks_data = [(0x2000 + i * 1024, 1024, 0) for i in range(8)] - num_descs = len(src_blocks_data) + num_fa_descs = len(src_blocks_data) splits = list( - worker._build_local_splits_from_plan( - plan, - src_blocks_data, - num_descs, - ) + worker._build_local_splits(engine_id, src_blocks_data, num_fa_descs) ) assert len(splits) == remote_tp_size @@ -108,22 +164,94 @@ def test_build_src_split_handles(self, remote_tp_size): for _, length, _ in handle: assert length == 1024 // remote_tp_size + @pytest.mark.parametrize( + "remote_tp_size,total_num_kv_heads", + [(2, 4), (2, 8), (4, 8)], + ) + def test_fa_offsets_p_gt_d(self, remote_tp_size, total_num_kv_heads): + """Verify concrete FA offsets for multi-head P>D (the previously buggy path). + + With local_tp=1, the full local block covers all heads. Each remote + rank's slice should land at the correct byte offset proportional to + its position in the local tensor. + """ + tp_rank = 0 + tp_size = 1 + engine_id = "remote_0" + local_block_len = 1024 + + fa_spec = _make_fa_spec(num_kv_heads=total_num_kv_heads // tp_size) + fa_slices = fa_spec.get_tp_transfer_slices( + tp_rank, tp_size, remote_tp_size, total_num_kv_heads + ) + remote_ranks = _remote_ranks_from_slices(fa_slices) + + worker = _make_mock_worker_for_splits( + group_specs=[fa_spec], + tp_mappings=(fa_slices,), + remote_ranks=remote_ranks, + engine_id=engine_id, + ) + base_addr = 0x4000 + src_blocks_data = [(base_addr, local_block_len, 0)] + splits = list(worker._build_local_splits(engine_id, src_blocks_data, 1)) + + assert len(splits) == remote_tp_size + chunk = local_block_len // remote_tp_size + for idx, (rank, sl) in enumerate(sorted(fa_slices.items())): + expected_offset = ( + sl.local_write_offset * local_block_len // len(sl.local_shard) + ) + # Offsets should tile the local block without overlap + assert expected_offset == idx * chunk + addr, length, dev = splits[idx][0] + assert addr == base_addr + expected_offset + assert length == chunk + assert dev == 0 + class TestMambaPlanSplitHandles: """Verify split handles for Mamba with FA/SSM distinction.""" def test_fa_and_ssm_different_split_factors(self): """Section 0 split by num_attn_reads, section 1 by abs_tp.""" - fa_readers = (0,) - ssm_readers = (0, 1) - plan = TPMapping( - source_ranks_per_group=(fa_readers, ssm_readers), - all_source_ranks=(0, 1), - rank_to_attention_slot={0: 0, 1: 0}, - rank_offset_factor=0, + engine_id = "remote_0" + # total_kv_heads=1 < remote_tp=2 triggers GQA dedup: + # only remote rank 0 holds unique FA data. + total_num_kv_heads = 1 + + fa_spec = _make_fa_spec(num_kv_heads=1) + mamba_spec = MagicMock(spec=MambaSpec) + + # local_tp=1, remote_tp=2 + # FA: 1 unique slice (reads from remote 0, GQA dedup skips rank 1) + # Mamba: 2 slices (reads from remote 0 and 1) + fa_slices = fa_spec.get_tp_transfer_slices(0, 1, 2, total_num_kv_heads) + + shard_mamba = ShardRange(0, 1, 1) + ssm_slices = { + 0: TPTransferSlice( + remote_rank=0, + remote_shard=shard_mamba, + local_shard=shard_mamba, + transfer_range=shard_mamba, + ), + 1: TPTransferSlice( + remote_rank=1, + remote_shard=shard_mamba, + local_shard=shard_mamba, + transfer_range=shard_mamba, + ), + } + remote_ranks = _remote_ranks_from_slices(fa_slices, ssm_slices) + + worker = _make_mock_worker_for_splits( + group_specs=[fa_spec, mamba_spec], + tp_mappings=(fa_slices, ssm_slices), + remote_ranks=remote_ranks, + engine_id=engine_id, ) - worker = _make_mock_worker_for_splits((FullAttentionSpec, MambaSpec)) # 2 FA descs + 1 SSM desc src_blocks_data = [ (1000, 200, 0), # FA desc 0 @@ -131,16 +259,27 @@ def test_fa_and_ssm_different_split_factors(self): (3000, 400, 0), # SSM desc 0 ] - splits = list(worker._build_local_splits_from_plan(plan, src_blocks_data, 2)) + splits = list(worker._build_local_splits(engine_id, src_blocks_data, 2)) assert len(splits) == 2 # 2 source ranks - # Rank 0 (FA source, p_idx=0): - # FA: chunk=200//1=200, slot=0 → (1000, 200, 0), (2000, 200, 0) - # SSM: chunk=400//2=200, idx=0 → (3000, 200, 0) - assert splits[0] == [(1000, 200, 0), (2000, 200, 0), (3000, 200, 0)] + # Rank 0 is in fa_slices -> uses local_write_offset for FA offset + fa_chunk = 200 // len(fa_slices) + ssm_chunk = 400 // len(ssm_slices) + + # Rank 0 (remote_idx=0): + # FA: chunk=200//1=200 (only 1 FA slice) + # offset = local_write_offset * local_block_len // len(local_shard) + # SSM: chunk=400//2=200, offset = remote_idx(0) * 200 + sl = fa_slices[0] + fa_offset_r0 = sl.local_write_offset * 200 // len(sl.local_shard) + assert splits[0][0] == (1000 + fa_offset_r0, fa_chunk, 0) + assert splits[0][1] == (2000 + fa_offset_r0, fa_chunk, 0) + assert splits[0][2] == (3000 + 0 * ssm_chunk, ssm_chunk, 0) - # Rank 1 (not FA source, p_idx=1): - # FA: chunk=200//1=200, slot=0 (skip_fa) → (1000, 200, 0), (2000, 200, 0) - # SSM: chunk=400//2=200, idx=1 → (3200, 200, 0) - assert splits[1] == [(1000, 200, 0), (2000, 200, 0), (3200, 200, 0)] + # Rank 1 (remote_idx=1): + # FA: rank 1 NOT in fa_slices -> GQA-deduped placeholder (addr, chunk, dev) + # SSM: chunk=400//2=200, offset = remote_idx(1) * 200 + assert splits[1][0] == (1000, fa_chunk, 0) + assert splits[1][1] == (2000, fa_chunk, 0) + assert splits[1][2] == (3000 + 1 * ssm_chunk, ssm_chunk, 0) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/tp_mapping.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/tp_mapping.py deleted file mode 100644 index b034b7605087..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/tp_mapping.py +++ /dev/null @@ -1,142 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""TP mapping computation for NIXL KV cache transfers.""" - -from __future__ import annotations - -from dataclasses import dataclass - -import numpy as np - -from vllm.distributed.kv_transfer.kv_connector.utils import ( - BlockIds, - TransferTopology, -) -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheSpec, MambaSpec - -# ====================================================================== -# Data structures -# ====================================================================== - - -@dataclass(frozen=True) -class ReadSpec: - """Specification for a single remote block read operation.""" - - remote_rank: int - local_block_ids: BlockIds - remote_block_ids: BlockIds - - -def _is_attention_spec(spec_type: type[KVCacheSpec]) -> bool: - return issubclass(spec_type, AttentionSpec) - - -def _is_ssm_spec(spec_type: type[KVCacheSpec]) -> bool: - return issubclass(spec_type, MambaSpec) - - -@dataclass(frozen=True) -class TPMapping: - """Complete local-to-remote TP mapping for one remote engine. - - Generated once per remote engine during handshake. - """ - - # Remote TP ranks that this local rank reads from, per group. - # Position = local piece index. - source_ranks_per_group: tuple[tuple[int, ...], ...] - - # Superset of all source ranks (union of all groups). - all_source_ranks: tuple[int, ...] - - # Maps each source rank to its FA head slot index. - rank_to_attention_slot: dict[int, int] - - # FA head offset factor for hetero-TP (D_TP > P_TP). - rank_offset_factor: int - - -# ====================================================================== -# TP mapping computation -# ====================================================================== - - -def compute_tp_mapping( - transfer_topology: TransferTopology, - remote_tp_size: int, - group_spec_types: tuple[type[KVCacheSpec], ...], -) -> TPMapping: - """Build the complete local-to-remote TP mapping. - - Computes source ranks, head slot assignments, and the rank offset - factor in a single pass. - """ - tp_rank = transfer_topology.tp_rank - tp_size = transfer_topology.tp_size - total_num_kv_heads = transfer_topology.total_num_kv_heads - # --- Attention source ranks --- - if transfer_topology.is_mla or tp_size >= remote_tp_size: - # D (local TP) > P (remote TP): multiple local ranks read different chunks from - # *one* remote rank, corresponding to different kv heads. - # For MLA, we only need one remote since cache is duplicated. When P TP=k*TP k, - # this will spread mla ranks to read from remote k*tp_rank. - attn_ranks = [tp_rank * remote_tp_size // tp_size] - else: - # P (remote TP) > D (local TP): one local rank - # reads from multiple remote ranks. - # GQA dedup: when K < remote_tp_size, several remote ranks - # hold the same KV head. np.unique keeps only the first - # rank per unique head so we don't issue redundant reads. - abs_tp = remote_tp_size // tp_size - start = tp_rank * abs_tp - heads = np.arange(start, start + abs_tp) * total_num_kv_heads // remote_tp_size - _, unique_idx = np.unique(heads, return_index=True) - attn_ranks = (start + np.sort(unique_idx)).tolist() - - # --- SSM source ranks --- - has_ssm = any(_is_ssm_spec(t) for t in group_spec_types) - if has_ssm: - if tp_size < remote_tp_size: - abs_tp = remote_tp_size // tp_size - ssm_ranks = list(range(tp_rank * abs_tp, (tp_rank + 1) * abs_tp)) - else: - ssm_ranks = list(attn_ranks) - else: - ssm_ranks = [] - - all_ranks = sorted(set(attn_ranks) | set(ssm_ranks)) - - # --- Per-group ordered source ranks --- - source_ranks_per_group = tuple( - tuple(ssm_ranks) if _is_ssm_spec(t) else tuple(attn_ranks) - for t in group_spec_types - ) - - # --- Attention head slots --- - head_to_slot: dict[int, int] = {} - for i, r in enumerate(attn_ranks): - head_to_slot[r * total_num_kv_heads // remote_tp_size] = i - rank_to_attention_slot = { - r: head_to_slot.get(r * total_num_kv_heads // remote_tp_size, 0) - for r in all_ranks - } - - # --- Rank offset factor --- - if transfer_topology.is_mla or tp_size <= remote_tp_size: - # We don't index into remote for reading, no offset needed. - rank_offset_factor = 0 - elif tp_size > total_num_kv_heads: - local_head = tp_rank * total_num_kv_heads // tp_size - p_start = attn_ranks[0] * total_num_kv_heads // remote_tp_size - rank_offset_factor = local_head - p_start - else: - # D TP > P TP: we index into remote to read different heads depending on rank. - rank_offset_factor = tp_rank % (tp_size // remote_tp_size) - - return TPMapping( - source_ranks_per_group=source_ranks_per_group, - all_source_ranks=tuple(all_ranks), - rank_to_attention_slot=rank_to_attention_slot, - rank_offset_factor=rank_offset_factor, - ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/utils.py index 2fa3829eaecb..d0b72464a27b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/utils.py @@ -10,7 +10,6 @@ from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_socket -from vllm.v1.kv_cache_interface import KVCacheSpec, UniformTypeKVCacheSpecs # Supported platforms and types of kv transfer buffer. # {device: tuple of supported kv buffer types} @@ -47,11 +46,3 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: finally: if ctx is not None: ctx.destroy(linger=0) - - -def get_representative_spec_type(spec: KVCacheSpec) -> type[KVCacheSpec]: - if isinstance(spec, UniformTypeKVCacheSpecs): - # All inner specs are the same type; pick any. - inner = next(iter(spec.kv_cache_specs.values())) - return type(inner) - return type(spec) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 0d30d4a692ad..8afcb1ea82fe 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -11,6 +11,7 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast import msgspec @@ -43,16 +44,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import ( NixlKVConnectorStats, ) -from vllm.distributed.kv_transfer.kv_connector.v1.nixl.tp_mapping import ( - ReadSpec, - TPMapping, - _is_attention_spec, - _is_ssm_spec, - compute_tp_mapping, -) from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import ( _NIXL_SUPPORTED_DEVICE, - get_representative_spec_type, zmq_ctx, ) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( @@ -69,8 +62,10 @@ from vllm.utils.network_utils import make_zmq_path from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.kv_cache_interface import ( + AttentionSpec, FullAttentionSpec, MambaSpec, + TPTransferSlice, UniformTypeKVCacheSpecs, ) from vllm.v1.worker.block_table import BlockTable @@ -78,14 +73,29 @@ if TYPE_CHECKING: from vllm.config import VllmConfig - from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec logger = init_logger(__name__) +@dataclass(frozen=True) +class ReadSpec: + """Specification for a single remote block read operation.""" + + remote_rank: int + local_block_ids: BlockIds + remote_block_ids: BlockIds + + class NixlConnectorWorker: """Implementation of Worker side methods""" + def _get_representative_spec(self, group) -> "KVCacheSpec": + spec = group.kv_cache_spec + if isinstance(spec, UniformTypeKVCacheSpecs): + return next(iter(spec.kv_cache_specs.values())) + return spec + def _compute_desc_ids( self, block_ids: BlockIds, @@ -121,12 +131,15 @@ def _compute_desc_ids( all_descs: list[np.ndarray] = [] for i, group in enumerate(block_ids): group_arr = np.asarray(group) - if _is_attention_spec(self._group_spec_types[i]): + spec = self._get_representative_spec( + self.kv_cache_config.kv_cache_groups[i] + ) + if isinstance(spec, AttentionSpec): fa_region_ids = np.arange(num_fa_regions)[:, None] all_descs.append( (fa_region_ids * num_blocks + group_arr[None, :]).flatten() ) - elif _is_ssm_spec(self._group_spec_types[i]): + elif isinstance(spec, MambaSpec): # NOTE (NickLucche) SSM and Attention block regions can # be exchanged arbitrarily by manager. Therefore, descs # are laid out as: @@ -143,52 +156,70 @@ def _compute_desc_ids( ).flatten() ) else: - raise ValueError( - f"Unknown spec type {self._group_spec_types[i]} at index {i}" - ) + raise ValueError(f"Unknown spec type {type(spec)} at index {i}") return np.concatenate(all_descs) - def _build_local_splits_from_plan( + def _build_local_splits( self, - plan: TPMapping, + engine_id: EngineId, src_blocks_data: list[tuple[int, int, int]], num_fa_descs: int, ) -> Iterator[list[tuple[int, int, int]]]: """Build split handle data for P_TP > D_TP scenario. num_fa_descs is the boundary between FA and SSM descriptors. - Split counts are derived from source_ranks_per_group lengths. - FA uses rank_to_attention_slot for the slot offset; - SSM uses the rank's positional index. + Split counts are derived from per-group slice counts. """ - fa_idx = next( - i for i, t in enumerate(self._group_spec_types) if _is_attention_spec(t) + assert self.transfer_topo is not None + + fa_group_idx = next( + i + for i, group in enumerate(self.kv_cache_config.kv_cache_groups) + if isinstance(self._get_representative_spec(group), AttentionSpec) ) - fa_num_splits = len(plan.source_ranks_per_group[fa_idx]) + fa_slices = self.tp_mappings[engine_id][fa_group_idx] - has_ssm_descs = num_fa_descs < len(src_blocks_data) ssm_idx = next( - (i for i, t in enumerate(self._group_spec_types) if _is_ssm_spec(t)), + ( + i + for i, group in enumerate(self.kv_cache_config.kv_cache_groups) + if isinstance(self._get_representative_spec(group), MambaSpec) + ), None, ) ssm_num_splits = ( - len(plan.source_ranks_per_group[ssm_idx]) - if has_ssm_descs and ssm_idx is not None + len(self.tp_mappings[engine_id][ssm_idx]) + if num_fa_descs < len(src_blocks_data) and ssm_idx is not None else 0 ) - for p_idx, p_rank in enumerate(plan.all_source_ranks): - fa_slot = plan.rank_to_attention_slot.get(p_rank, 0) - + for remote_idx, remote_rank in enumerate(self.remote_ranks[engine_id]): handle: list[tuple[int, int, int]] = [] - for j, (addr, local_len, dev) in enumerate(src_blocks_data): + for j, (addr, local_block_len, dev) in enumerate(src_blocks_data): if j < num_fa_descs: - chunk = local_len // fa_num_splits - handle.append((addr + fa_slot * chunk, chunk, dev)) + chunk = local_block_len // len(fa_slices) + if remote_rank in fa_slices: + fa_offset = ( + fa_slices[remote_rank].local_write_offset + * local_block_len + // len(fa_slices[remote_rank].local_shard) + ) + handle.append((addr + fa_offset, chunk, dev)) + else: + # GQA-deduped rank: no FA transfer issued for this + # rank (empty block_ids at transfer time), but NIXL + # requires a registered descriptor. Offset 0 is safe + # because this descriptor is never used at transfer. + handle.append((addr, chunk, dev)) else: - chunk = local_len // ssm_num_splits - handle.append((addr + p_idx * chunk, chunk, dev)) + # Assume SSM always sharded + assert ( + ssm_idx is not None + and remote_rank in self.tp_mappings[engine_id][ssm_idx] + ) + chunk = local_block_len // ssm_num_splits + handle.append((addr + remote_idx * chunk, chunk, dev)) yield handle def __init__( @@ -438,14 +469,10 @@ def __init__( self._physical_blocks_per_logical_kv_block = 1 self._sync_block_size_with_kernel() - # Unwrap UniformTypeKVCacheSpecs to get the representative spec type - self._group_spec_types = tuple( - get_representative_spec_type(g.kv_cache_spec) - for g in self.kv_cache_config.kv_cache_groups - ) - # Per-engine TP mappings. Generated during handshake. - self.tp_mappings: dict[EngineId, TPMapping] = {} + # tp_mappings[engine_id][group_idx] = {rank: TPTransferSlice, ...} + self.tp_mappings: dict[EngineId, tuple[dict[int, TPTransferSlice], ...]] = {} + self.remote_ranks: dict[EngineId, tuple[int, ...]] = {} self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( "enforce_handshake_compat", True @@ -1134,17 +1161,26 @@ def _build_fa_local( def _build_fa_remote( self, - plan: TPMapping, + engine_id: EngineId, nixl_agent_meta: NixlAgentMetadata, block_size_ratio: int, ) -> list[tuple[int, int, int]]: """Build remote FA descriptors for all layers.""" assert self.transfer_topo is not None fa_group_idx = next( - i for i, t in enumerate(self._group_spec_types) if _is_attention_spec(t) + i + for i, group in enumerate(self.kv_cache_config.kv_cache_groups) + if isinstance(self._get_representative_spec(group), AttentionSpec) ) - num_attn_reads = len(plan.source_ranks_per_group[fa_group_idx]) - num_blocks = nixl_agent_meta.num_blocks + fa_slices = self.tp_mappings[engine_id][fa_group_idx] + num_attn_reads = len(fa_slices) + + # Head offset into remote rank's tensor. + # D_TP >= P_TP: single slice with non-zero offset. + # P_TP > D_TP: all slices read from offset 0. + fa_slice = next(iter(fa_slices.values())) + assert num_attn_reads == 1 or fa_slice.remote_read_offset == 0 + result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): # Read our whole local region size from remote.. @@ -1157,27 +1193,45 @@ def _build_fa_remote( local_block_len = remote_kv_block_len local_block_len = local_block_len // num_attn_reads - rank_offset = plan.rank_offset_factor * remote_kv_block_len + remote_block_len = nixl_agent_meta.block_lens[i] + if self.transfer_topo.is_kv_layout_blocks_first: + remote_block_len //= 2 + rank_offset = ( + fa_slice.remote_read_offset + * remote_block_len + // len(fa_slice.remote_shard) + ) page_size = nixl_agent_meta.block_lens[i] - for block_id in range(num_blocks): + for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * page_size - # For each block, grab the kv heads chunk belonging to current local - # tp rank of size local_block_len. - addr = base_addr + block_offset + rank_offset - result.append((addr, local_block_len, nixl_agent_meta.device_id)) + # For each block, grab the kv heads chunk belonging to current + # local tp rank of size local_block_len. + result.append( + ( + base_addr + block_offset + rank_offset, + local_block_len, + nixl_agent_meta.device_id, + ) + ) if self.transfer_topo.virtually_split_kv_in_blocks: # With FlashInfer index V separately to allow head splitting. - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False + second_split = ( + self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=False, mamba_view=False + ) + // num_attn_reads ) - second_split = second_split // num_attn_reads - for block_id in range(num_blocks): + for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset # Hop over the first split of remote page, K, to read V. - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + v_addr = ( + base_addr + + block_offset + + rank_offset + + nixl_agent_meta.block_lens[i] // 2 + ) result.append((v_addr, second_split, nixl_agent_meta.device_id)) return result @@ -1300,10 +1354,17 @@ def add_remote_agent( transfer_topo.register_remote_engine(engine_id, transfer_info) logger.info("Transfer plan: %s", transfer_topo.describe(engine_id)) - self.tp_mappings[engine_id] = compute_tp_mapping( - transfer_topology=transfer_topo, - remote_tp_size=remote_tp_size, - group_spec_types=self._group_spec_types, + self.tp_mappings[engine_id] = tuple( + self._get_representative_spec(group).get_tp_transfer_slices( + transfer_topo.tp_rank, + transfer_topo.tp_size, + remote_tp_size, + self.model_config.get_total_num_kv_heads(), + ) + for group in self.kv_cache_config.kv_cache_groups + ) + self.remote_ranks[engine_id] = tuple( + sorted({r for group_map in self.tp_mappings[engine_id] for r in group_map}) ) remote_agent_name = self.nixl_wrapper.add_remote_agent( @@ -1339,8 +1400,6 @@ def add_remote_agent( tp_ratio, ) - plan = self.tp_mappings[engine_id] - ### (Optional) Register local agent memory regions. MLA is not split. if ( tp_ratio < 0 @@ -1352,8 +1411,8 @@ def add_remote_agent( # we only do this once per remote tp_size (replica-friendly). self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] - for handle_data in self._build_local_splits_from_plan( - plan, + for handle_data in self._build_local_splits( + engine_id, self.src_blocks_data, self.num_descs, ): @@ -1371,7 +1430,7 @@ def add_remote_agent( # Register all remote blocks, but only the corresponding kv heads. blocks_data = self._build_fa_remote( - plan, + engine_id, nixl_agent_meta, block_size_ratio, ) @@ -2021,7 +2080,6 @@ def _send_heartbeats(self, metadata: NixlConnectorMetadata) -> None: def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): assert meta.remote is not None and self.transfer_topo is not None engine_id = meta.remote.engine_id - plan = self.tp_mappings[engine_id] remote_info = self.transfer_topo.get_engine_info(engine_id) tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) @@ -2037,18 +2095,18 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_rank=rank, local_block_ids=[ list(local_block_ids[g]) - if rank in plan.source_ranks_per_group[g] + if rank in self.tp_mappings[engine_id][g] else [] for g in range(num_groups) ], remote_block_ids=[ list(remote_block_ids[g]) - if rank in plan.source_ranks_per_group[g] + if rank in self.tp_mappings[engine_id][g] else [] for g in range(num_groups) ], ) - for rank in plan.all_source_ranks + for rank in self.remote_ranks[engine_id] ] # D may have to perform multiple reads from different remote ranks. @@ -2333,7 +2391,12 @@ def _apply_prefix_caching( for i, remote_group in enumerate(remote_block_ids): num_local_blocks = len(local_block_ids[i]) num_remote_blocks = len(remote_group) - if _is_ssm_spec(self._group_spec_types[i]): + if isinstance( + self._get_representative_spec( + self.kv_cache_config.kv_cache_groups[i] + ), + MambaSpec, + ): assert num_local_blocks == num_remote_blocks else: max_padding = max( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 31ee89bc72aa..3255d5d28e14 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -78,6 +78,105 @@ def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool: return get_kv_quant_mode(kv_cache_dtype).is_per_token_head +# --------------------------------------------------------------------------- +# TP transfer slice +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class ShardRange: + """A contiguous [start, stop) range along any sharding dimension. + + Carries global_size for safety assertions (prevents mixing ranges from + different dimension spaces). + """ + + start: int + stop: int + global_size: int + + def __post_init__(self): + assert 0 <= self.start <= self.stop <= self.global_size, ( + f"Invalid ShardRange [{self.start}:{self.stop}] " + f"for global_size={self.global_size}" + ) + + def __len__(self) -> int: + return self.stop - self.start + + def offset_within(self, parent: ShardRange) -> int: + """Return self.start's position within parent's range. + + Asserts both ranges share the same global_size and self is within parent. + """ + assert self.global_size == parent.global_size, ( + f"Dimension mismatch: {self.global_size} vs {parent.global_size}" + ) + assert self.start >= parent.start and self.stop <= parent.stop, ( + f"{self} is not within {parent}" + ) + return self.start - parent.start + + def intersect(self, other: ShardRange) -> ShardRange | None: + """Find overlap with another range. Returns None if disjoint.""" + assert self.global_size == other.global_size, ( + f"Dimension mismatch: {self.global_size} vs {other.global_size}" + ) + lo = max(self.start, other.start) + hi = min(self.stop, other.stop) + if lo >= hi: + return None + return ShardRange(lo, hi, self.global_size) + + def __repr__(self) -> str: + return f"[{self.start}:{self.stop}]/{self.global_size}" + + +@dataclass(frozen=True) +class TPTransferSlice: + """Describes what KV heads to read from one remote rank. + + All ShardRanges are in global head coordinates (over total_num_kv_heads). + transfer_range is the intersection of remote_shard and local_shard. + """ + + remote_rank: int + remote_shard: ShardRange + local_shard: ShardRange + transfer_range: ShardRange + + def __post_init__(self): + assert self.transfer_range.global_size == self.remote_shard.global_size, ( + f"Dimension mismatch: transfer_range {self.transfer_range.global_size} " + f"vs remote_shard {self.remote_shard.global_size}" + ) + assert ( + self.transfer_range.start >= self.remote_shard.start + and self.transfer_range.stop <= self.remote_shard.stop + ), ( + f"transfer_range {self.transfer_range} " + f"not within remote_shard {self.remote_shard}" + ) + + @property + def remote_read_offset(self) -> int: + """Element offset into remote rank's tensor to start reading.""" + return self.transfer_range.offset_within(self.remote_shard) + + @property + def local_write_offset(self) -> int: + """Element offset into local tensor to start writing.""" + return self.transfer_range.offset_within(self.local_shard) + + def __repr__(self) -> str: + return ( + f"TPTransferSlice(rank={self.remote_rank}, " + f"transfer={self.transfer_range}, " + f"remote={self.remote_shard}, " + f"local={self.local_shard})" + ) + + class KVCacheSpecKind(str, Enum): FULL_ATTENTION = "full_attention" MLA_ATTENTION = "mla_attention" @@ -129,6 +228,23 @@ def copy_with_new_block_size(self, block_size: int) -> Self: """ return replace(self, block_size=block_size) + def get_tp_transfer_slices( + self, + local_tp_rank: int, + local_tp_size: int, + remote_tp_size: int, + total_num_kv_heads: int, + ) -> dict[int, TPTransferSlice]: + """Compute transfer slices for this local rank. + + Returns a mapping from remote_rank -> TPTransferSlice describing + which remote ranks to read from and what sub-range to transfer. + Must be overridden by subclasses that participate in PD transfers. + """ + raise NotImplementedError( + f"{type(self).__name__} does not implement get_tp_transfer_slices" + ) + @classmethod def merge(cls, specs: list[Self]) -> Self: """ @@ -183,6 +299,81 @@ def real_page_size_bytes(self) -> int: * get_dtype_size(self.dtype) ) + # ------------------------------------------------------------------ + # TP transfer slice interface + # ------------------------------------------------------------------ + + def get_tp_transfer_slices( + self, + local_tp_rank: int, + local_tp_size: int, + remote_tp_size: int, + total_num_kv_heads: int, + ) -> dict[int, TPTransferSlice]: + """Compute transfer slices for this local rank. + + Returns rank -> TPTransferSlice mapping. Logic mirrors the old + compute_tp_mapping attention-rank selection on main. + """ + + def _shard_for_rank(rank: int, tp_size: int) -> ShardRange: + s = rank * total_num_kv_heads // tp_size + e = (rank + 1) * total_num_kv_heads // tp_size + if s == e: + # Replicated: this rank holds same head as a neighbor. + # Express as size-1 shard for the head it actually holds. + return ShardRange( + s, s + max(1, total_num_kv_heads // tp_size), total_num_kv_heads + ) + return ShardRange(s, e, total_num_kv_heads) + + local_shard = _shard_for_rank(local_tp_rank, local_tp_size) + + if local_tp_size >= remote_tp_size: + # D_TP >= P_TP: read from one remote rank. + remote_rank = local_tp_rank * remote_tp_size // local_tp_size + remote_shard = _shard_for_rank(remote_rank, remote_tp_size) + transfer_range = remote_shard.intersect(local_shard) + assert transfer_range is not None, ( + f"local_shard {local_shard} and remote_shard {remote_shard} " + f"are disjoint for rank {remote_rank}" + ) + return { + remote_rank: TPTransferSlice( + remote_rank=remote_rank, + remote_shard=remote_shard, + local_shard=local_shard, + transfer_range=transfer_range, + ) + } + else: + # P_TP > D_TP: read from multiple remotes with GQA dedup. + abs_tp = remote_tp_size // local_tp_size + start = local_tp_rank * abs_tp + + result: dict[int, TPTransferSlice] = {} + seen_heads: set[int] = set() + for r in range(start, start + abs_tp): + head_start = r * total_num_kv_heads // remote_tp_size + if head_start in seen_heads: + continue + seen_heads.add(head_start) + + remote_shard = _shard_for_rank(r, remote_tp_size) + transfer_range = remote_shard.intersect(local_shard) + assert transfer_range is not None, ( + f"local_shard {local_shard} and remote_shard {remote_shard} " + f"are disjoint for rank {r}" + ) + result[r] = TPTransferSlice( + remote_rank=r, + remote_shard=remote_shard, + local_shard=local_shard, + transfer_range=transfer_range, + ) + + return result + @dataclass(frozen=True, kw_only=True) class FullAttentionSpec(AttentionSpec): @@ -346,6 +537,32 @@ def __post_init__(self): super().__post_init__() _apply_alignment_padding(self) + # ------------------------------------------------------------------ + # TP transfer slice interface (MLA: cache is always replicated) + # ------------------------------------------------------------------ + + def get_tp_transfer_slices( + self, + local_tp_rank: int, + local_tp_size: int, + remote_tp_size: int, + total_num_kv_heads: int, + ) -> dict[int, TPTransferSlice]: + """MLA cache is fully replicated -- read full block from one remote. + + Load-balances by picking the aligned remote rank. + """ + aligned_remote = local_tp_rank * remote_tp_size // local_tp_size + shard = ShardRange(0, 1, 1) + return { + aligned_remote: TPTransferSlice( + remote_rank=aligned_remote, + remote_shard=shard, + local_shard=shard, + transfer_range=shard, + ) + } + @property def storage_block_size(self) -> int: return self.block_size // self.compress_ratio @@ -507,6 +724,32 @@ class SlidingWindowMLASpec(SlidingWindowSpec): def __post_init__(self): _apply_alignment_padding(self) + # ------------------------------------------------------------------ + # TP transfer slice interface (MLA: cache is always replicated) + # ------------------------------------------------------------------ + + def get_tp_transfer_slices( + self, + local_tp_rank: int, + local_tp_size: int, + remote_tp_size: int, + total_num_kv_heads: int, + ) -> dict[int, TPTransferSlice]: + """MLA cache is fully replicated -- read full block from one remote. + + Load-balances by picking the aligned remote rank. + """ + aligned_remote = local_tp_rank * remote_tp_size // local_tp_size + shard = ShardRange(0, 1, 1) + return { + aligned_remote: TPTransferSlice( + remote_rank=aligned_remote, + remote_shard=shard, + local_shard=shard, + transfer_range=shard, + ) + } + @property def storage_block_size(self) -> int: return self.block_size // self.compress_ratio @@ -579,6 +822,45 @@ def page_size_bytes(self) -> int: return self.page_size_padded return page_size + def get_tp_transfer_slices( + self, + local_tp_rank: int, + local_tp_size: int, + remote_tp_size: int, + total_num_kv_heads: int, + ) -> dict[int, TPTransferSlice]: + """Mamba SSM state is TP-sharded but not along KV heads. + + The actual byte-level sub-projection slicing (conv x/B/C + ssm) + is handled by _build_mamba_remote via MambaConvSplitInfo. + Here we only determine which remote ranks to read from. + Uses a placeholder ShardRange(0,1,1) since the real byte-level + decomposition is handled by MambaConvSplitInfo. + """ + shard = ShardRange(0, 1, 1) + if local_tp_size >= remote_tp_size: + remote_rank = local_tp_rank * remote_tp_size // local_tp_size + return { + remote_rank: TPTransferSlice( + remote_rank=remote_rank, + remote_shard=shard, + local_shard=shard, + transfer_range=shard, + ) + } + else: + abs_tp = remote_tp_size // local_tp_size + start = local_tp_rank * abs_tp + return { + r: TPTransferSlice( + remote_rank=r, + remote_shard=shard, + local_shard=shard, + transfer_range=shard, + ) + for r in range(start, start + abs_tp) + } + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: if vllm_config.cache_config.mamba_cache_mode == "all": max_model_len = vllm_config.model_config.max_model_len