diff --git a/tests/v1/kv_connector/unit/test_mooncake_connector.py b/tests/v1/kv_connector/unit/test_mooncake_connector.py index 7b6fe3af0ce2..83202e023c9e 100644 --- a/tests/v1/kv_connector/unit/test_mooncake_connector.py +++ b/tests/v1/kv_connector/unit/test_mooncake_connector.py @@ -631,7 +631,7 @@ def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes(): mock_thread.return_value.is_alive.return_value = False worker.use_mla = True - worker.kv_topo.is_mla = True + worker.transfer_topo.is_mla = True # MLA cache tensor: shape[-2] is the block size. mla_cache = torch.zeros((2, 16, 96), dtype=torch.float16) @@ -692,9 +692,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): # Override TP rank/size to simulate P TP=2 prefill_worker.tp_rank = P_TP_RANK prefill_worker.tp_size = P_TP_SIZE - # Update shared dict so kv_topo sees correct TP size prefill_worker._tp_size[prefill_worker.engine_id] = P_TP_SIZE - prefill_worker.kv_topo.tp_rank = P_TP_RANK + prefill_worker.transfer_topo.tp_rank = P_TP_RANK + prefill_worker.transfer_topo.tp_size = P_TP_SIZE prefill_worker.kv_caches_base_addr = [0x1000] prefill_worker.block_len_per_layer = [local_block_len] @@ -714,7 +714,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): send_meta.ready.set() # Compute target D ranks using the production code path - target_d_ranks = prefill_worker.kv_topo.get_target_remote_ranks(d_tp_size) + target_d_ranks = prefill_worker.transfer_topo.handshake_target_ranks(d_tp_size) mock_socket = AsyncMock(spec=zmq.asyncio.Socket) mock_socket.send_multipart = AsyncMock() diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index d67b14e8dd4a..50e83aa2ef20 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -21,7 +21,7 @@ from vllm.config import KVTransferConfig, set_current_vllm_config from vllm.distributed.kv_transfer.kv_connector.utils import ( KVOutputAggregator, - TpKVTopology, + TransferTopology, get_current_attn_backend, ) from vllm.distributed.kv_transfer.kv_connector.v1 import nixl @@ -463,19 +463,20 @@ def __init__( test_shape = self.attn_backends[0].get_kv_cache_shape( num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 ) - self.kv_topo = TpKVTopology( + self.transfer_topo = TransferTopology( tp_rank=self.tp_rank, + tp_size=self.world_size, + block_size=self.block_size, engine_id=self.engine_id, - remote_tp_size=self._tp_size, # shared state - remote_block_size=self._block_size, # shared state is_mla=self.use_mla, + is_mamba=False, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, tensor_shape=test_shape, ) self.compat_hash = compute_nixl_compatibility_hash( - self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks + self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks ) def _nixl_handshake( @@ -496,7 +497,7 @@ def _nixl_handshake( # Adjust remote block length metadata to satisfy heterogeneous TP # invariants enforced during handshake validation. 
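For reference, the tp_ratio sign convention the heterogeneous-TP tests and handshake code above rely on, as a minimal self-contained sketch mirroring `TransferTopology.tp_ratio` (the divisibility asserts are assumed to hold):

```python
def tp_ratio(local_tp: int, remote_tp: int) -> int:
    # Positive: local workers read the same remote worker in groups of
    # this size. Negative: each local worker reads |ratio| remote workers.
    if local_tp >= remote_tp:
        assert local_tp % remote_tp == 0
        return local_tp // remote_tp
    assert remote_tp % local_tp == 0
    return -(remote_tp // local_tp)

assert tp_ratio(8, 2) == 4   # D TP=8 reading from P TP=2
assert tp_ratio(2, 8) == -4  # P TP > D TP: the ratio flips sign
```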
remote_block_lens = list(self.block_len_per_layer) - tp_ratio = self.kv_topo.tp_ratio(remote_tp_size) + tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size) if remote_tp_size > self.world_size: # P TP > D TP case, block_len of remote is smaller remote_block_lens = [ @@ -731,8 +732,9 @@ def check_handshake(remote_tp_size: int): assert set(remote_agents.keys()) == set(range(tp_ratio)) remote_engine_id = worker.REMOTE_ENGINE_ID - assert worker._tp_size[remote_engine_id] == remote_tp_size - assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id) + remote_info = worker.transfer_topo.get_engine_info(remote_engine_id) + assert remote_info.remote_tp_size == remote_tp_size + assert -tp_ratio == worker.transfer_topo.tp_ratio(remote_tp_size) # ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio @@ -796,7 +798,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size_mla( (conn_p0.connector_worker, conn_p1.connector_worker) ): worker.world_size = p_tp_size - worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size} + worker.transfer_topo.tp_size = p_tp_size worker.tp_rank = rank worker.use_mla = True @@ -2337,7 +2339,7 @@ def test_compatibility_hash_validation( remote_hash = compute_nixl_compatibility_hash( remote_vllm_config, decode_worker.backend_name, - decode_worker.kv_topo.cross_layers_blocks, + decode_worker.transfer_topo.cross_layers_blocks, ) prefill_block_size = config_overrides.get("block_size", 16) @@ -2424,12 +2426,13 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) test_shape = backend.get_kv_cache_shape( num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 ) - decode_worker.kv_topo = TpKVTopology( + decode_worker.transfer_topo = TransferTopology( tp_rank=decode_worker.tp_rank, + tp_size=decode_worker.world_size, + block_size=decode_worker.block_size, engine_id=decode_worker.engine_id, - remote_tp_size=decode_worker._tp_size, # shared state - remote_block_size=decode_worker._block_size, # shared state is_mla=decode_worker.use_mla, + is_mamba=False, total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), attn_backends=[backend], tensor_shape=test_shape, @@ -2438,7 +2441,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) decode_worker.compat_hash = compute_nixl_compatibility_hash( decode_worker.vllm_config, decode_worker.backend_name, - decode_worker.kv_topo.cross_layers_blocks, + decode_worker.transfer_topo.cross_layers_blocks, ) if error_scenario == "handshake_decode_error": diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 30913ff98ee2..5b6090173591 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -152,13 +152,14 @@ def test_read_blocks_for_req_expands_remote_ids( remote_engine_id = "remote-engine" if has_mamba: - worker._mamba_phys_ratio = {remote_engine_id: remote_ratio} + worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio} - # Mock kv_topo: empty remote ranks skips the transfer machinery entirely, - # isolating the block-ID expansion logic. 
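A sketch of the logical-to-physical block-id expansion those HMA tests isolate. `expand_block_ids` is a hypothetical helper (not in this diff), assuming each logical block maps to `ratio` consecutive physical blocks, where `ratio` is the per-engine `_physical_blocks_per_logical` value:

```python
def expand_block_ids(logical_ids: list[int], ratio: int) -> list[int]:
    # Assumption: physical blocks are laid out contiguously per logical block.
    return [bid * ratio + i for bid in logical_ids for i in range(ratio)]

assert expand_block_ids([0, 3], 2) == [0, 1, 6, 7]
```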
- worker.kv_topo = MagicMock() - worker.kv_topo.get_target_remote_ranks_from_engine_id.return_value = [] - worker.kv_topo.tp_ratio_from_engine_id.return_value = 1 + # Mock transfer_topo: empty remote ranks skips the transfer machinery + # entirely, isolating the block-ID expansion logic. + worker.transfer_topo = MagicMock() + worker.transfer_topo.target_remote_ranks.return_value = [] + worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1) + worker.transfer_topo.tp_ratio.return_value = 1 metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( @@ -317,7 +318,7 @@ def test_get_block_descs_ids_hybrid_ssm(): worker._has_mamba = True worker._is_mamba_group = [False, True] worker._physical_blocks_per_logical_kv_block = 1 - worker._mamba_phys_ratio = {engine_id: 1} + worker._physical_blocks_per_logical = {engine_id: 1} worker.block_len_per_layer = [100] # num_descs = num_regions * num_blocks (no blocks_first doubling) worker.num_descs = 2 * num_blocks @@ -355,7 +356,7 @@ def test_get_block_descs_ids_kernel_block_mismatch(): worker._has_mamba = True worker._is_mamba_group = [False, True] worker._physical_blocks_per_logical_kv_block = ratio - worker._mamba_phys_ratio = {engine_id: ratio} + worker._physical_blocks_per_logical = {engine_id: ratio} worker.block_len_per_layer = [100] worker.num_descs = 2 * num_blocks # 800 @@ -532,15 +533,15 @@ def test_has_mamba_init( ((9216, 524288), 4096, 131), ], ) -def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio): - """Verify that compute_mamba_phys_ratio is TP-dependent. +def test_compute_physical_blocks_per_logical(ssm_sizes, block_len, expected_ratio): + """Verify that compute_physical_blocks_per_logical is TP-dependent. With dimension-sharded Mamba state, the ratio differs across TP sizes (e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why - _mamba_phys_ratio must be stored per-engine. + _physical_blocks_per_logical must be stored per-engine. """ from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( - compute_mamba_phys_ratio, + compute_physical_blocks_per_logical, ) - assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio + assert compute_physical_blocks_per_logical(ssm_sizes, block_len) == expected_ratio diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index e75c1c0a3a45..63b56eddfaed 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -5,7 +5,7 @@ """ from collections.abc import Iterator -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, cast import torch @@ -319,31 +319,139 @@ def yield_req_data( ) -@dataclass -class TpKVTopology: +def get_current_attn_backends( + vllm_config: VllmConfig, layer_names: list[str] | None = None +) -> list[type[AttentionBackend]]: + """Get all distinct attention backends for the given layers. + + Args: + vllm_config: The current vLLM configuration. + layer_names: Optional list of layer names to scope the lookup. + When None, all attention layers are considered. + + Returns: + Deduplicated list of attention backend classes. 
+ """ + layer_type = cast(type[Any], AttentionLayerBase) + layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names) + if layers: + seen: dict[str, type[AttentionBackend]] = {} + for layer in layers.values(): + backend = layer.get_attn_backend() + seen[backend.full_cls_name()] = backend + return list(seen.values()) + + # Fallback for tests, when static_forward_context is empty. + logger.debug( + "No layers found in the vLLM config. Falling back to default attention backend." + ) + from vllm.v1.attention.selector import get_attn_backend + + return [ + get_attn_backend( + head_size=vllm_config.model_config.get_head_size(), + dtype=vllm_config.model_config.dtype, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + use_mla=vllm_config.model_config.use_mla, + ) + ] + + +def get_current_attn_backend( + vllm_config: VllmConfig, layer_names: list[str] | None = None +) -> type[AttentionBackend]: + """Get the first attention backend for the given layers.""" + return get_current_attn_backends(vllm_config, layer_names)[0] + + +# ---- Per-engine transfer info ---- + + +@dataclass(frozen=True) +class EngineTransferInfo: + """Common per-remote-engine transfer state, computed at handshake. + + Stored per ``engine_id`` inside ``TransferTopology._engines``. """ - Helper class for tensor parallel and KV topology information for - mapping between local and remote TP workers. + + remote_tp_size: int + + remote_block_len: int + """Block length (bytes)""" + + remote_block_size: int + """Tokens per block.""" + + remote_physical_blocks_per_logical: int + """Physical blocks per logical block.""" + + +@dataclass(frozen=True) +class MambaEngineTransferInfo(EngineTransferInfo): + """Extends ``EngineTransferInfo`` with Mamba-hybrid transfer geometry. + + For hybrid SSM+Attention models, FA and Mamba layers may require + different numbers of reads from different remote ranks. This + dataclass captures that per-engine transfer plan. """ + remote_fa_source_ranks: tuple[int, ...] + """Remote ranks carrying unique FA heads for this local rank.""" + + remote_all_source_ranks: tuple[int, ...] + """All remote ranks this local rank reads from (FA + Mamba).""" + + remote_num_fa_reads: int + """Number of distinct remote ranks needed for FA data.""" + + remote_num_mamba_reads: int + """Number of distinct remote ranks needed for Mamba data.""" + + remote_fa_descriptor_bytes: int + """Byte size of one FA K (or V) descriptor entry.""" + + is_remote_replicated: bool + """Whether the remote engine has replicated KV heads + (remote_tp_size > total_num_kv_heads).""" + + remote_physical_heads: int + """Physical KV heads stored per remote rank.""" + + +# ---- Transfer topology ---- + + +@dataclass +class TransferTopology: + """Single source of truth for local TP identity and per-engine remote info.""" + tp_rank: int - remote_tp_size: dict[EngineId, int] + tp_size: int + block_size: int + engine_id: EngineId is_mla: bool + is_mamba: bool total_num_kv_heads: int attn_backends: list[type[AttentionBackend]] - engine_id: EngineId - remote_block_size: dict[EngineId, int] tensor_shape: torch.Size | None = None - is_mamba: bool = False def __post_init__(self): + self.local_physical_heads = max(1, self.total_num_kv_heads // self.tp_size) + + self._engines: dict[EngineId, EngineTransferInfo] = {} + self._fa_source_sets: dict[EngineId, frozenset[int]] = {} + self._fa_source_indices: dict[EngineId, dict[int, int]] = {} + # Figure out whether the first dimension of the cache is K/V - # or num_blocks. 
This is used to register the memory regions correctly.
+        # or num_blocks.
         attn_backend = self.attn_backends[0]
         if not self.is_mamba:
             _MOCK_BLOCK_SIZE = 16
             kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape(
-                num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
+                num_blocks=1,
+                block_size=_MOCK_BLOCK_SIZE,
+                num_kv_heads=1,
+                head_size=1,
             )
             logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
             # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
@@ -358,11 +466,9 @@ def __post_init__(self):
             self._cross_layers_blocks = (
                 len(self.tensor_shape) == len(kv_cache_shape) + 1
             )
-            self.tensor_shape: torch.Size
 
             if self._cross_layers_blocks:
                 logger.debug("Using cross-layer KV cache")
-                # prepend layers dimension
                 _MOCK_NUM_LAYERS = 80
                 kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
                 try:
@@ -372,15 +478,81 @@ def __post_init__(self):
                 except (AttributeError, NotImplementedError):
                     assert self.tensor_shape is not None
                     kv_cache_stride_order = tuple(range(len(self.tensor_shape)))
-
-                # In case of cross layers permute kv_cache_shape according to
-                # stride_order to retrieve physical position of block_size
                 kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
 
+    # ============================================================
+    # Engine registration
+    # ============================================================
+
+    def register_remote_engine(
+        self,
+        remote_engine_id: EngineId,
+        remote_tp_size: int,
+        remote_block_size: int,
+        remote_block_len: int,
+        remote_physical_blocks_per_logical: int,
+        *,
+        local_block_len: int = 0,
+    ) -> EngineTransferInfo:
+        """Register a remote engine, replacing the shared per-worker dict state.
+
+        Only remote engines should be registered here; the local engine's
+        identity (tp_size, block_size, etc.) is set via ``__init__`` params.
+
+        For Mamba models, also computes the Mamba transfer plan and
+        builds the FA source lookup caches.
+
+        Args:
+            local_block_len: Local representative block_len (bytes).
+                Required for Mamba models to compute ``fa_descriptor_bytes``.
+        """
+        assert remote_engine_id != self.engine_id, (
+            f"Cannot register local engine {self.engine_id} as remote. "
+ ) + if remote_engine_id in self._engines: + return self._engines[remote_engine_id] + info: EngineTransferInfo + if self.is_mamba: + info = self._build_mamba_info( + remote_tp_size=remote_tp_size, + remote_block_size=remote_block_size, + remote_block_len=remote_block_len, + remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), + local_block_len=local_block_len, + ) + assert isinstance(info, MambaEngineTransferInfo) + self._fa_source_sets[remote_engine_id] = frozenset( + info.remote_fa_source_ranks + ) + self._fa_source_indices[remote_engine_id] = { + r: i for i, r in enumerate(info.remote_fa_source_ranks) + } + else: + info = EngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_len=remote_block_len, + remote_block_size=remote_block_size, + remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), + ) + self._engines[remote_engine_id] = info + return info + + def get_engine_info(self, remote_engine_id: EngineId) -> EngineTransferInfo: + return self._engines[remote_engine_id] + + # ============================================================ + # Layout properties + # ============================================================ + @property def is_kv_layout_blocks_first(self) -> bool: return self._is_kv_layout_blocks_first + @property + def cross_layers_blocks(self) -> bool: + return self._cross_layers_blocks + @property def split_k_and_v(self) -> bool: # Whether to register regions for K and V separately (when present). @@ -388,29 +560,16 @@ def split_k_and_v(self) -> bool: self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first ) - @property - def tp_size(self) -> int: - return self.remote_tp_size[self.engine_id] + # ============================================================ + # Common methods + # ============================================================ - @property - def block_size(self) -> int: - return self.remote_block_size[self.engine_id] + def tp_ratio(self, remote_tp_size: int) -> int: + """Calculate the tensor parallel ratio between local and remote TP. - @property - def cross_layers_blocks(self) -> bool: - return self._cross_layers_blocks - - def tp_ratio( - self, - remote_tp_size: int, - ) -> int: - """ - Calculate the tensor parallel ratio between local and remote TP. - We can think of it as the number of local TP workers-per-remote TP - workers. Local workers will read from the same remote TP worker in - groups of size `tp_ratio`.If remote tp_size > local tp_size, the - ratio is flipped (remote_size/local_size) and the returned value is - negative. + Positive when local_tp >= remote_tp (local workers read from the + same remote worker in groups of size ``tp_ratio``). Negative when + remote_tp > local_tp (ratio is flipped). """ if self.tp_size >= remote_tp_size: assert self.tp_size % remote_tp_size == 0, ( @@ -418,78 +577,65 @@ def tp_ratio( f"by remote tensor parallel size {remote_tp_size}." ) return self.tp_size // remote_tp_size - assert remote_tp_size % self.tp_size == 0, ( f"Remote tensor parallel size {remote_tp_size} is not divisible " f"by local tensor parallel size {self.tp_size}." ) - # P TP > D TP case, return the ratio as negative - return -remote_tp_size // self.tp_size + return -(remote_tp_size // self.tp_size) - def block_size_ratio( - self, - remote_block_size: int, - ) -> int: - """ - Calculate the block size ratio between local and remote TP. 
- """ + def block_size_ratio(self, remote_block_size: int) -> int: + """Calculate the block size ratio between local and remote.""" assert self.block_size % remote_block_size == 0, ( f"Local block size {self.block_size} is not divisible " f"by remote block size {remote_block_size} or vice versa." ) return self.block_size // remote_block_size - def tp_ratio_from_engine_id( - self, - remote_engine_id: EngineId, - ) -> int: - remote_tp_size = self.remote_tp_size[remote_engine_id] - return self.tp_ratio(remote_tp_size) - - def block_size_ratio_from_engine_id( - self, - remote_engine_id: EngineId, - ) -> int: - remote_block_size = self.remote_block_size[remote_engine_id] - return self.block_size_ratio(remote_block_size) - - def is_kv_replicated(self, engine_id: EngineId) -> bool: - """ - Whether the KV cache is replicated across TP workers due to the + def is_kv_replicated(self, remote_engine_id: EngineId) -> bool: + """Whether the KV cache is replicated across TP workers due to the number of TP workers being greater than the number of KV heads. - When they are equal, each TP rank still owns one distinct KV head, - so this is not considered replication. """ - tp_size = self.remote_tp_size[engine_id] - return tp_size > self.total_num_kv_heads + return self._engines[remote_engine_id].remote_tp_size > self.total_num_kv_heads def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: # MLA is always replicated as the hidden dim can't be split. return self.is_mla or self.is_kv_replicated(remote_engine_id) - def get_target_remote_ranks( - self, - remote_tp_size: int, - ) -> list[int]: - """ - Get the remote TP rank (on P) that the current local TP rank - (on D) will read from. When remote tp_size > local tp_size, we - read from multiple remote ranks. + @property + def local_replicates_kv_cache(self) -> bool: + """Whether the local engine's KV cache is replicated.""" + return self.is_mla or self.tp_size > self.total_num_kv_heads + + def handshake_target_ranks(self, remote_tp_size: int) -> list[int]: + """Pre-registration: compute which remote TP ranks to handshake with. + + Pure math based on local/remote TP sizes — does not require + the remote engine to be registered yet. """ tp_ratio = self.tp_ratio(remote_tp_size) if tp_ratio > 0: return [self.tp_rank // tp_ratio] + abs_ratio = -tp_ratio + return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)] - # P TP > D TP case, D reads from |tp_ratio| remote workers. - tp_ratio = -tp_ratio - return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)] + def target_remote_ranks(self, remote_engine_id: EngineId) -> list[int]: + """Get the remote TP rank(s) that the current local TP rank will + read from. When remote tp_size > local tp_size, reads from + multiple remote ranks. - def get_target_remote_ranks_from_engine_id( - self, - remote_engine_id: EngineId, - ) -> list[int]: - remote_tp_size = self.remote_tp_size[remote_engine_id] - return self.get_target_remote_ranks(remote_tp_size) + For Mamba models, returns the precomputed ``all_source_ranks`` + (FA + Mamba union). 
+ """ + info = self._engines[remote_engine_id] + if isinstance(info, MambaEngineTransferInfo): + return list(info.remote_all_source_ranks) + + tp_ratio = self.tp_ratio(info.remote_tp_size) + if tp_ratio > 0: + return [self.tp_rank // tp_ratio] + # remote TP > local TP: read from |tp_ratio| remote workers + abs_ratio = -tp_ratio + return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)] def get_transfer_cache_regions( self, cache: torch.Tensor, layer_spec: "KVCacheSpec" @@ -498,331 +644,139 @@ def get_transfer_cache_regions( also accounting for hybrid SSM models specificities. """ if isinstance(layer_spec, MambaSpec): - # Register the whole kv cache shared tensor, including SSM/Conv. This is - # similar to FI with the difference that SSM/Conv have different sizes + # Register the whole kv cache shared tensor, including + # SSM/Conv. conv, ssm = cache return [conv] - # Check may be hacky but it's matching `_update_hybrid_attention_mamba_layout`. + # Check may be hacky but it's matching + # `_update_hybrid_attention_mamba_layout`. if self.is_mamba and cache.shape[0] == 2: - # When MAMBA is present, all backends are blocks first, so that blocks - # can be shared between attention layers and mamba layers. Runner - # `_update_hybrid_attention_mamba_layout` already adjusted strides - # for FlashAttn-like backends so its num_blocks first. - # Swap [2<>num_blocks] dims to get required layout for hybrid SSM. + # When MAMBA is present, all backends are blocks first, so + # that blocks can be shared between attention layers and mamba + # layers. Runner already adjusted strides for FlashAttn-like + # backends so its num_blocks first. + # Swap [2<>num_blocks] dims for hybrid SSM layout. cache = cache.transpose(0, 1) # Regular case: backends like FA register K/V in separate regions return cache if self.split_k_and_v else [cache] + # ============================================================ + # Mamba-specific methods + # ============================================================ -# ---- Mamba-HMA hetero-TP transfer config ---- -# -# Key insight: with hetero-TP (P_TP > D_TP), FA KV cache may be -# replicated across P ranks (when P_TP > num_kv_heads), but Mamba -# conv/SSM state is almost always uniquely sharded per P rank. So the -# number of P ranks D must read from can differ between FA and Mamba, -# and they must be handled separately. - - -def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: - """Physical KV head range stored in a rank's KV cache tensor. - - When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. - When ``tp_size > num_heads``: 1 physical head per rank. Heads are - distributed **contiguously** (matching vLLM's GQA weight partitioning): - consecutive ranks share a head before moving to the next one. - """ - if tp_size <= num_heads: - assert num_heads % tp_size == 0 - per_rank = num_heads // tp_size - return range(rank * per_rank, (rank + 1) * per_rank) - else: - h = rank * num_heads // tp_size - return range(h, h + 1) - - -def _range_overlap(a: range, b: range) -> range: - start = max(a.start, b.start) - stop = min(a.stop, b.stop) - return range(start, max(start, stop)) - + def should_skip_fa(self, remote_engine_id: EngineId, remote_rank: int) -> bool: + """Whether to skip FA groups for this remote rank (mamba-only).""" + return remote_rank not in self._fa_source_sets[remote_engine_id] -@dataclass -class HeteroTPTransferConfig: - """Precomputed transfer plan for one (D rank, P engine) pair. 
- - Currently only instantiated for Mamba-HMA (hybrid SSM+Attention) models - where FA and mamba require different splitting factors. Could be extended - to other model types that need non-uniform hetero-TP transfer sizing. - - All descriptor sizes are computed here. The guarantee is: - local_entry_size == remote_entry_size (for NIXL) - - Attributes that start with ``fa_`` concern FlashAttention KV cache. - Attributes that start with ``mamba_`` concern Mamba conv/SSM state. - """ - - # ---- Input parameters (from handshake) ---- - tp_ratio: int - K: int # total_num_kv_heads (before TP sharding) - d_tp: int # D engine's tensor_parallel_size - p_tp: int # P engine's tensor_parallel_size - d_rank: int # this D worker's TP rank - use_mla: bool - - # Per-layer block lengths (bytes, K+V combined for blocks_first). - # Uniform across layers for current models. - d_block_len: int # D's block_len_per_layer (representative) - p_block_len: int # P's block_len_per_layer (from handshake) - is_blocks_first: bool # kv_topo.is_kv_layout_blocks_first - - # ---- Derived: computed in __post_init__ ---- - # - # Physical heads per rank (what the KV tensor actually stores) - d_physical_heads: int = field(init=False) - p_physical_heads: int = field(init=False) - - # How many distinct P ranks D needs for FA data - physical_fa_num_reads: int = field(init=False) - - # Which P ranks contribute unique FA heads (ordered by head index) - fa_read_targets: list[int] = field(init=False) - - # All P ranks needed for mamba (always abs_tp for tp_ratio < 0) - mamba_num_reads: int = field(init=False) - - # All P ranks this D rank communicates with (FA ∪ mamba) - transfer_targets: list[int] = field(init=False) - - # FA descriptor entry size (K or V side, for blocks_first layout) - # Guaranteed: fa_entry_size is the SAME for local handle AND remote desc. - fa_entry_size: int = field(init=False) - - # Replication flags - is_d_replicated: bool = field(init=False) - is_p_replicated: bool = field(init=False) - - # Pre-built set for fast lookup - _fa_target_set: frozenset[int] = field(init=False, repr=False) - # Map: P rank → index in fa_read_targets (for head slot offset) - _fa_target_index: dict[int, int] = field(init=False, repr=False) - - def __post_init__(self) -> None: - K = self.K - self.is_d_replicated = self.d_tp > K - self.is_p_replicated = self.p_tp > K - - self.d_physical_heads = max(1, K // self.d_tp) - self.p_physical_heads = max(1, K // self.p_tp) - - abs_tp = -self.tp_ratio if self.tp_ratio < 0 else 1 - - # ---- Mamba range (computed first so FA can prefer ranks in it) ---- - mamba_range: range | None = None - if self.tp_ratio < 0: - mamba_range = range(self.d_rank * abs_tp, (self.d_rank + 1) * abs_tp) - - # ---- FA read targets ---- - if self.use_mla or self.tp_ratio >= 0: - self.physical_fa_num_reads = 1 - self.fa_read_targets = ( - [0] - if self.use_mla - # Must match kv_topo.get_target_remote_ranks (d_rank // tp_ratio). - else [ - self.d_rank // self.tp_ratio if self.tp_ratio > 0 else self.d_rank - ] - ) - else: - d_needs = _physical_head_range(self.d_tp, K, self.d_rank) - # When mamba range exists, prefer P ranks within it so that - # FA targets are a subset of mamba transfer_targets (avoids - # orphaned FA targets outside the transfer loop). 
- search_range = mamba_range if mamba_range is not None else range(self.p_tp) - seen: set[tuple[int, int]] = set() - targets: list[int] = [] - for p in search_range: - p_has = _physical_head_range(self.p_tp, K, p) - ov = _range_overlap(d_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - targets.append(p) - if not targets: - # Fallback: search globally (should not happen in practice) - for p in range(self.p_tp): - p_has = _physical_head_range(self.p_tp, K, p) - ov = _range_overlap(d_needs, p_has) - if len(ov) > 0: - key = (ov.start, ov.stop) - if key not in seen: - seen.add(key) - targets.append(p) - self.fa_read_targets = targets - self.physical_fa_num_reads = len(targets) - - self._fa_target_set = frozenset(self.fa_read_targets) - self._fa_target_index = {r: i for i, r in enumerate(self.fa_read_targets)} - - # ---- Mamba targets ---- - if mamba_range is not None and abs_tp > self.physical_fa_num_reads: - self.mamba_num_reads = abs_tp - self.transfer_targets = list(mamba_range) - else: - self.mamba_num_reads = self.physical_fa_num_reads - self.transfer_targets = list(self.fa_read_targets) - - # ---- FA entry size ---- - # For blocks_first: block_len_per_layer includes K+V; // 2 gives K (or V). - # Use min(D, P) because D indexes into P when tp_ratio > 0, - # and P is the natural unit when tp_ratio < 0. - effective_block_len = min(self.d_block_len, self.p_block_len) - if self.is_blocks_first: - self.fa_entry_size = effective_block_len // 2 - else: - self.fa_entry_size = effective_block_len - - self._validate() - - def _validate(self) -> None: - """Cross-check internal consistency.""" - if self.is_d_replicated and self.is_p_replicated and self.tp_ratio > 0: - logger.info( - "Both-replicated hetero-TP: D_TP=%d > P_TP=%d > K=%d. " - "Using d_rank // tp_ratio routing with relative head offset.", - self.d_tp, - self.p_tp, - self.K, - ) - - # FA targets must be a subset of transfer_targets - tt_set = set(self.transfer_targets) - for t in self.fa_read_targets: - if t not in tt_set: - logger.error( - "FA target P rank %d is NOT in transfer_targets %s. " - "This will cause missed FA reads!", - t, - self.transfer_targets, - ) - - # For tp_ratio < 0 with blocks_first: D_K_half / reads should == P_K_half - if ( - self.is_blocks_first - and self.tp_ratio < 0 - and self.physical_fa_num_reads > 0 - ): - d_k_half = self.d_block_len // 2 - p_k_half = self.p_block_len // 2 - expected_local = d_k_half // self.physical_fa_num_reads - if expected_local != p_k_half: - logger.warning( - "FA size mismatch: D_K_half=%d / reads=%d = %d, " - "but P_K_half=%d. This may indicate a head count or " - "Mamba-HMA inflation inconsistency.", - d_k_half, - self.physical_fa_num_reads, - expected_local, - p_k_half, - ) + def fa_head_slot(self, remote_engine_id: EngineId, remote_rank: int) -> int: + """Index into local FA block for this remote rank's head data. - # ---- Query methods ---- - - def should_skip_fa(self, p_rank: int) -> bool: - """Whether to skip FA groups for this P rank (mamba-only transfer).""" - return p_rank not in self._fa_target_set - - def fa_head_slot(self, p_rank: int) -> int: - """Index into D's FA block for this P rank's head data. - - For P ranks in fa_read_targets, returns 0, 1, ..., reads-1. - For P ranks NOT in fa_read_targets (replicated duplicates), - returns the slot of the matching FA target with the same head. + For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1. 
+ For ranks NOT in ``fa_source_ranks`` (replicated duplicates), + returns the slot of the matching source rank with the same head. """ - if p_rank in self._fa_target_index: - return self._fa_target_index[p_rank] - # Duplicate head: find which fa_target has the same physical head - p_head = _physical_head_range(self.p_tp, self.K, p_rank) - for target in self.fa_read_targets: - t_head = _physical_head_range(self.p_tp, self.K, target) - if _range_overlap(p_head, t_head): - return self._fa_target_index[target] - return 0 # fallback - - def fa_rank_offset(self, remote_kv_block_len: int) -> int: - """Byte offset into P's FA block for this D rank. - - When D is replicated (D_TP > K), multiple D ranks share a head. - Computes offset *relative to the target P rank's first head* - so it works regardless of how many heads P has. - When neither side replicates, falls back to tp_rank % tp_ratio. - Returns 0 when D does not index into P's block. + fa_index = self._fa_source_indices[remote_engine_id] + if remote_rank in fa_index: + return fa_index[remote_rank] + mamba_info = self._engines[remote_engine_id] + assert isinstance(mamba_info, MambaEngineTransferInfo) + K = self.total_num_kv_heads + remote_tp = mamba_info.remote_tp_size + r_head = self._physical_head_range(remote_tp, K, remote_rank) + for target in mamba_info.remote_fa_source_ranks: + t_head = self._physical_head_range(remote_tp, K, target) + if self._range_overlap(r_head, t_head): + return fa_index[target] + return 0 + + def fa_rank_offset( + self, remote_engine_id: EngineId, remote_kv_block_len: int + ) -> int: + """Byte offset into remote FA block for this local rank. + + When local TP is replicated (local_tp > K), multiple local ranks + share a head. Computes offset *relative to the target remote + rank's first head* so it works regardless of how many heads the + remote has. Returns 0 when local does not index into remote. """ - if self.use_mla or self.tp_ratio <= 0: + mamba_info = self._engines[remote_engine_id] + assert isinstance(mamba_info, MambaEngineTransferInfo) + tp_ratio = self.tp_ratio(mamba_info.remote_tp_size) + if self.is_mla or tp_ratio <= 0: return 0 - if self.is_d_replicated: - d_head = self.d_rank * self.K // self.d_tp - p_rank = self.fa_read_targets[0] - p_start = p_rank * self.K // self.p_tp - return (d_head - p_start) * remote_kv_block_len - return self.d_rank % self.tp_ratio * remote_kv_block_len - - @property - def needs_split_handles(self) -> bool: - """Whether per-P-rank split handles are needed. + K = self.total_num_kv_heads + is_local_replicated = self.tp_size > K + if is_local_replicated: + local_head = self.tp_rank * K // self.tp_size + p_rank = mamba_info.remote_fa_source_ranks[0] + p_start = p_rank * K // mamba_info.remote_tp_size + return (local_head - p_start) * remote_kv_block_len + return self.tp_rank % tp_ratio * remote_kv_block_len + + def needs_split_handles(self, remote_engine_id: EngineId) -> bool: + """Whether per-remote-rank split handles are needed. True when FA and mamba have different read counts, requiring different splitting factors in the local handle. 
""" - return self.tp_ratio < 0 and not self.use_mla and len(self.transfer_targets) > 1 + mamba_info = self._engines[remote_engine_id] + assert isinstance(mamba_info, MambaEngineTransferInfo) + tp_ratio = self.tp_ratio(mamba_info.remote_tp_size) + return ( + tp_ratio < 0 + and not self.is_mla + and len(mamba_info.remote_all_source_ranks) > 1 + ) def compute_split_handle_data( self, + remote_engine_id: EngineId, src_blocks_data: list[tuple[int, int, int]], num_fa_descs: int, abs_tp: int, ) -> list[list[tuple[int, int, int]]]: - """Compute per-P-rank (addr, len, tp) triples for Mamba-HMA split handles. + """Per-remote-rank (addr, len, dev) triples for Mamba-HMA split + handles. FA descriptors (indices < num_fa_descs) are sliced by - ``physical_fa_num_reads``; mamba descriptors are sliced uniformly + ``remote_num_fa_reads``; mamba descriptors are sliced uniformly by ``abs_tp``. - - Returns one list of triples per transfer target. """ + mamba_info = self._engines[remote_engine_id] + assert isinstance(mamba_info, MambaEngineTransferInfo) all_handle_data: list[list[tuple[int, int, int]]] = [] - for p_idx, p_rank in enumerate(self.transfer_targets): + for p_idx, p_rank in enumerate(mamba_info.remote_all_source_ranks): handle_data: list[tuple[int, int, int]] = [] - skip_fa = self.should_skip_fa(p_rank) - fa_slot = self.fa_head_slot(p_rank) if not skip_fa else 0 - - for j, (addr, local_len, tp) in enumerate(src_blocks_data): + skip_fa = self.should_skip_fa(remote_engine_id, p_rank) + fa_slot = self.fa_head_slot(remote_engine_id, p_rank) if not skip_fa else 0 + for j, (addr, local_len, dev) in enumerate(src_blocks_data): if j < num_fa_descs: - assert self.physical_fa_num_reads >= 1 - fa_chunk = local_len // self.physical_fa_num_reads - handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, tp)) + assert mamba_info.remote_num_fa_reads >= 1 + fa_chunk = local_len // mamba_info.remote_num_fa_reads + handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, dev)) else: mamba_chunk = local_len // abs_tp - handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, tp)) + handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, dev)) all_handle_data.append(handle_data) return all_handle_data def filter_block_ids_for_rank( self, + remote_engine_id: EngineId, remote_rank: int, local_ids: BlockIds, remote_ids: BlockIds, is_mamba_group: list[bool], ) -> tuple[BlockIds, BlockIds]: - """Zero out FA groups for P ranks outside fa_read_targets. + """Zero out FA groups for remote ranks outside ``fa_source_ranks``. Returns (filtered_local_ids, filtered_remote_ids). When the - remote rank carries FA data for this D rank, returns the inputs - unchanged. + remote rank carries FA data for this local rank, returns the + inputs unchanged. 
""" - if not self.should_skip_fa(remote_rank): + if not self.should_skip_fa(remote_engine_id, remote_rank): return local_ids, remote_ids num_groups = len(local_ids) filtered_local: list[list[int]] = [ @@ -833,108 +787,184 @@ def filter_block_ids_for_rank( ] return filtered_local, filtered_remote - def describe(self) -> str: - """One-line summary for logging.""" - return ( - f"HeteroTPTransferConfig(" - f"tp_ratio={self.tp_ratio}, K={self.K}, " - f"d_tp={self.d_tp}, p_tp={self.p_tp}, d_rank={self.d_rank}, " - f"physical_fa_reads={self.physical_fa_num_reads}, " - f"mamba_reads={self.mamba_num_reads}, " - f"fa_targets={self.fa_read_targets}, " - f"transfer_targets={self.transfer_targets}, " - f"fa_entry_size={self.fa_entry_size}, " - f"d_block_len={self.d_block_len}, p_block_len={self.p_block_len})" + def describe(self, remote_engine_id: EngineId) -> str: + """One-line summary of transfer config for logging.""" + info = self._engines[remote_engine_id] + base = ( + f"tp_ratio={self.tp_ratio(info.remote_tp_size)}, " + f"K={self.total_num_kv_heads}, " + f"local_tp={self.tp_size}, " + f"remote_tp={info.remote_tp_size}, " + f"local_rank={self.tp_rank}, " + f"remote_block_len={info.remote_block_len}" ) + if isinstance(info, MambaEngineTransferInfo): + return ( + f"TransferTopology.mamba({base}, " + f"fa_reads={info.remote_num_fa_reads}, " + f"mamba_reads={info.remote_num_mamba_reads}, " + f"fa_sources={list(info.remote_fa_source_ranks)}, " + f"all_sources={list(info.remote_all_source_ranks)}, " + f"fa_desc_bytes={info.remote_fa_descriptor_bytes})" + ) + return f"TransferTopology({base})" + + # ============================================================ + # Private helpers + # ============================================================ + # Mamba-HMA hetero-TP transfer config: + # With hetero-TP (P_TP > D_TP), FA KV cache may be replicated across + # P ranks (when P_TP > num_kv_heads), but Mamba conv/SSM state is + # almost always uniquely sharded per P rank. So the number of P + # ranks D must read from can differ between FA and Mamba, and they + # must be handled separately. + + @staticmethod + def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range: + """Physical KV head range stored in a rank's KV cache tensor. + + When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank. + When ``tp_size > num_heads``: 1 physical head per rank. Heads are + distributed **contiguously** (matching vLLM's GQA weight partitioning): + consecutive ranks share a head before moving to the next one. + """ + if tp_size <= num_heads: + assert num_heads % tp_size == 0 + per_rank = num_heads // tp_size + return range(rank * per_rank, (rank + 1) * per_rank) + else: + h = rank * num_heads // tp_size + return range(h, h + 1) + @staticmethod + def _range_overlap(a: range, b: range) -> range: + start = max(a.start, b.start) + stop = min(a.stop, b.stop) + return range(start, max(start, stop)) -def get_current_attn_backends( - vllm_config: VllmConfig, layer_names: list[str] | None = None -) -> list[type[AttentionBackend]]: - """Get all distinct attention backends for the given layers. + # ============================================================ + # Private: build Mamba transfer info + # ============================================================ - Args: - vllm_config: The current vLLM configuration. - layer_names: Optional list of layer names to scope the lookup. - When None, all attention layers are considered. 
+ def _build_mamba_info( + self, + remote_tp_size: int, + remote_block_size: int, + remote_block_len: int, + remote_physical_blocks_per_logical: int, + local_block_len: int, + ) -> MambaEngineTransferInfo: + """Compute Mamba transfer plan.""" + K = self.total_num_kv_heads + local_tp = self.tp_size + local_rank = self.tp_rank + + is_remote_replicated = remote_tp_size > K + remote_physical_heads = max(1, K // remote_tp_size) + + if local_tp >= remote_tp_size: + assert local_tp % remote_tp_size == 0 + tp_ratio = local_tp // remote_tp_size + else: + assert remote_tp_size % local_tp == 0 + tp_ratio = -(remote_tp_size // local_tp) - Returns: - Deduplicated list of attention backend classes. - """ - layer_type = cast(type[Any], AttentionLayerBase) - layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names) - if layers: - seen: dict[str, type[AttentionBackend]] = {} - for layer in layers.values(): - backend = layer.get_attn_backend() - seen[backend.full_cls_name()] = backend - return list(seen.values()) + abs_tp = -tp_ratio if tp_ratio < 0 else 1 - # Fallback for tests, when static_forward_context is empty. - logger.debug( - "No layers found in the vLLM config. Falling back to default attention backend." - ) - from vllm.v1.attention.selector import get_attn_backend + mamba_range: range | None = None + if tp_ratio < 0: + mamba_range = range(local_rank * abs_tp, (local_rank + 1) * abs_tp) - return [ - get_attn_backend( - head_size=vllm_config.model_config.get_head_size(), - dtype=vllm_config.model_config.dtype, - kv_cache_dtype=vllm_config.cache_config.cache_dtype, - use_mla=vllm_config.model_config.use_mla, - ) - ] + # ---- FA read targets ---- + if self.is_mla or tp_ratio >= 0: + num_fa_reads = 1 + fa_source_ranks: list[int] = ( + [0] + if self.is_mla + else [local_rank // tp_ratio if tp_ratio > 0 else local_rank] + ) + else: + local_needs = self._physical_head_range(local_tp, K, local_rank) + search_range = ( + mamba_range if mamba_range is not None else range(remote_tp_size) + ) + seen: set[tuple[int, int]] = set() + fa_source_ranks = [] + for p in search_range: + p_has = self._physical_head_range(remote_tp_size, K, p) + ov = self._range_overlap(local_needs, p_has) + if len(ov) > 0: + key = (ov.start, ov.stop) + if key not in seen: + seen.add(key) + fa_source_ranks.append(p) + if not fa_source_ranks: + for p in range(remote_tp_size): + p_has = self._physical_head_range(remote_tp_size, K, p) + ov = self._range_overlap(local_needs, p_has) + if len(ov) > 0: + key = (ov.start, ov.stop) + if key not in seen: + seen.add(key) + fa_source_ranks.append(p) + num_fa_reads = len(fa_source_ranks) + # ---- All source ranks (mamba + FA) ---- + if mamba_range is not None and abs_tp > num_fa_reads: + num_mamba_reads = abs_tp + all_source_ranks = list(mamba_range) + else: + num_mamba_reads = num_fa_reads + all_source_ranks = list(fa_source_ranks) -def get_current_attn_backend( - vllm_config: VllmConfig, layer_names: list[str] | None = None -) -> type[AttentionBackend]: - """Get the first attention backend for the given layers.""" - return get_current_attn_backends(vllm_config, layer_names)[0] + # ---- FA descriptor bytes ---- + effective_block_len = min(local_block_len, remote_block_len) + if self.is_kv_layout_blocks_first: + fa_descriptor_bytes = effective_block_len // 2 + else: + fa_descriptor_bytes = effective_block_len + # ---- Validation ---- + is_local_replicated = local_tp > K + if is_local_replicated and is_remote_replicated and tp_ratio > 0: + logger.info( + "Both-replicated hetero-TP: 
local_tp=%d > remote_tp=%d > K=%d.", + local_tp, + remote_tp_size, + K, + ) + tt_set = set(all_source_ranks) + for t in fa_source_ranks: + if t not in tt_set: + logger.error( + "FA source rank %d NOT in all_source_ranks %s.", + t, + all_source_ranks, + ) + if self.is_kv_layout_blocks_first and tp_ratio < 0 and num_fa_reads > 0: + local_k_half = local_block_len // 2 + remote_k_half = remote_block_len // 2 + expected = local_k_half // num_fa_reads + if expected != remote_k_half: + logger.warning( + "FA size mismatch: local_k_half=%d / reads=%d = %d, " + "but remote_k_half=%d.", + local_k_half, + num_fa_reads, + expected, + remote_k_half, + ) -# TODO (ZhanqiuHu): Consolidate TpKVTopology and HeteroTPTransferConfig -# into a single engine-agnostic TransferTopology class. -# 6 of 9 HeteroTPTransferConfig init fields duplicate TpKVTopology data. -# -# @dataclass -# class EngineTransferInfo: -# """Per-remote-engine transfer state, computed at handshake.""" -# p_tp: int -# tp_ratio: int -# p_block_len: int -# block_size: int -# # Mamba-specific (None for non-mamba models) -# fa_read_targets: list[int] | None = None -# transfer_targets: list[int] | None = None -# physical_fa_num_reads: int | None = None -# mamba_num_reads: int | None = None -# fa_entry_size: int | None = None -# -# class TransferTopology: -# """Single source of truth for TP topology + transfer sizing.""" -# # Shared (set once at init, replaces duplicate fields) -# tp_rank: int # == TpKVTopology.tp_rank == HeteroTP.d_rank -# tp_size: int # == TpKVTopology.tp_size == HeteroTP.d_tp -# total_num_kv_heads: int # == HeteroTP.K -# is_mla: bool # == HeteroTP.use_mla -# is_mamba: bool -# is_blocks_first: bool # == HeteroTP.is_blocks_first -# d_block_len: int -# -# # Per-engine (populated via register_engine() at handshake) -# _engines: dict[EngineId, EngineTransferInfo] -# -# def register_engine(self, engine_id, p_tp, p_block_len, ...): ... -# -# # General (from TpKVTopology) -# def tp_ratio(self, engine_id) -> int: ... -# def target_remote_ranks(self, engine_id) -> list[int]: ... -# def is_kv_replicated(self, engine_id) -> bool: ... -# -# # Mamba-specific (from HeteroTPTransferConfig, gated by is_mamba) -# def fa_rank_offset(self, engine_id, block_len) -> int: ... -# def physical_fa_num_reads(self, engine_id) -> int: ... -# def transfer_targets(self, engine_id) -> list[int]: ... -# def should_skip_fa(self, engine_id, p_rank) -> bool: ... -# def filter_block_ids_for_rank(self, engine_id, ...) -> ...: ... 
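To make the FA source-rank selection above concrete, here is a standalone reproduction of the head-range overlap math (it mirrors the `_physical_head_range` and `_range_overlap` helpers; the TP sizes are toy values):

```python
def physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
    # Heads stored by `rank`: sharded when tp_size <= num_heads, otherwise
    # one (possibly duplicated) head per rank, assigned contiguously.
    if tp_size <= num_heads:
        per_rank = num_heads // tp_size
        return range(rank * per_rank, (rank + 1) * per_rank)
    h = rank * num_heads // tp_size
    return range(h, h + 1)

def range_overlap(a: range, b: range) -> range:
    start, stop = max(a.start, b.start), min(a.stop, b.stop)
    return range(start, max(start, stop))

# K=8 heads, P TP=4 (2 heads per rank), D TP=2 (4 heads per rank):
# D rank 1 needs heads 4..7, which live only on P ranks 2 and 3.
d_needs = physical_head_range(2, 8, 1)  # range(4, 8)
fa_sources = [
    p for p in range(4)
    if len(range_overlap(d_needs, physical_head_range(4, 8, p))) > 0
]
assert fa_sources == [2, 3]  # -> remote_num_fa_reads == 2
```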
+ return MambaEngineTransferInfo( + remote_tp_size=remote_tp_size, + remote_block_len=remote_block_len, + remote_block_size=remote_block_size, + remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical), + remote_fa_source_ranks=tuple(fa_source_ranks), + remote_all_source_ranks=tuple(all_source_ranks), + remote_num_fa_reads=num_fa_reads, + remote_num_mamba_reads=num_mamba_reads, + remote_fa_descriptor_bytes=fa_descriptor_bytes, + is_remote_replicated=is_remote_replicated, + remote_physical_heads=remote_physical_heads, + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index 67603e10ff60..2057c79fa58c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -21,7 +21,7 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import ( EngineId, - TpKVTopology, + TransferTopology, get_current_attn_backend, get_current_attn_backends, ) @@ -764,13 +764,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.debug("Detected kv cache layout %s", self.kv_cache_layout) self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size} - self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} - self.kv_topo = TpKVTopology( + self.transfer_topo = TransferTopology( tp_rank=self.tp_rank, + tp_size=self.tp_size, + block_size=self.block_size, engine_id=self.engine_id, - remote_tp_size=self._tp_size, # shared state - remote_block_size=self._block_size, # shared state is_mla=self.use_mla, + is_mamba=False, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=[backend], ) @@ -911,7 +911,7 @@ async def send_kv_to_decode( self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata ): pending_reqs: dict[ReqId, SendBlockMeta] = {} - remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size) + remote_tp_ranks = self.transfer_topo.handshake_target_ranks(meta.remote_tp_size) if meta.remote_tp_rank not in remote_tp_ranks: # This D worker does not pair with the P worker. 
msg = ( @@ -1256,7 +1256,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): seen_base_addresses = [] self.block_len_per_layer = [] - split_k_and_v = self.kv_topo.split_k_and_v + split_k_and_v = self.transfer_topo.split_k_and_v tensor_size_bytes = None for layer_name, cache_or_caches in kv_caches.items(): cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] @@ -1495,8 +1495,8 @@ def receive_kv( remote_engine_id: EngineId, pull_metas: dict[ReqId, PullReqMeta], ): - remote_tp_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( - remote_engine_id + remote_tp_ranks = self.transfer_topo.handshake_target_ranks( + self._tp_size[remote_engine_id] ) count = len(remote_tp_ranks) logger.debug( @@ -1587,7 +1587,7 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata): ) def _producer_cache_is_replicated(self) -> bool: - return self.kv_topo.replicates_kv_cache(self.engine_id) + return self.transfer_topo.local_replicates_kv_cache def _get_transfer_regions( self, base_addrs: list[int], block_lens: list[int] @@ -1595,7 +1595,7 @@ def _get_transfer_regions( return _expand_transfer_regions( base_addrs=base_addrs, block_lens=block_lens, - is_kv_layout_blocks_first=self.kv_topo.is_kv_layout_blocks_first, + is_kv_layout_blocks_first=self.transfer_topo.is_kv_layout_blocks_first, ) def _get_sender_transfer_plan( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 45aa33033e76..3724a773cd18 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -21,8 +21,8 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( BlockIds, EngineId, - HeteroTPTransferConfig, - TpKVTopology, + MambaEngineTransferInfo, + TransferTopology, get_current_attn_backends, kv_postprocess_blksize_and_layout_on_receive, kv_postprocess_blksize_on_receive, @@ -51,7 +51,7 @@ ) from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( MambaConvSplitInfo, - compute_mamba_phys_ratio, + compute_physical_blocks_per_logical, derive_mamba_conv_split, ) from vllm.distributed.parallel_state import ( @@ -270,14 +270,12 @@ def __init__( self._registered_descs: list[Any] = [] # ---- Mamba-HMA per-engine state (only used when self._has_mamba) ---- - # Per-engine transfer config (source of truth for FA/mamba sizing). - self._transfer_configs: dict[str, HeteroTPTransferConfig] = {} - # NOTE (ZhanqiuHu): _mamba_phys_ratio MUST be per-engine. - # compute_mamba_phys_ratio = ceil((conv_bytes + ssm_bytes) / block_len) + # NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine. + # physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len) # where conv/ssm bytes are per-TP-rank (dimension-sharded). With # heterogeneous TP the per-rank sizes differ, so the ratio differs: # e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261. - self._mamba_phys_ratio: dict[EngineId, int] = {} + self._physical_blocks_per_logical: dict[EngineId, int] = {} # In progress transfers. 
# [req_id -> list[handle]] @@ -323,10 +321,8 @@ def __init__( # lazy initialized in register_kv_caches self.compat_hash: str | None = None - self.kv_topo: TpKVTopology | None = None + self.transfer_topo: TransferTopology | None = None - self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} - self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} # With heterogeneous TP, P must wait for all assigned D TP workers to # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) @@ -356,7 +352,6 @@ def _sync_block_size_with_kernel(self) -> None: self.block_size // kernel_block_size ) self.block_size = kernel_block_size - self._block_size[self.engine_id] = kernel_block_size self.num_blocks *= self._physical_blocks_per_logical_kv_block def _nixl_handshake( @@ -385,8 +380,8 @@ def _nixl_handshake( # Regardless, only handshake with the remote TP rank(s) that current # local rank will read from. Note that With homogeneous TP, # this happens to be the same single rank_i. - assert self.kv_topo is not None - p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size) + assert self.transfer_topo is not None + p_remote_ranks = self.transfer_topo.handshake_target_ranks(remote_tp_size) remote_rank_to_agent_name = {} path = make_zmq_path("tcp", host, port) @@ -650,11 +645,11 @@ def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - self.kv_topo = TpKVTopology( + self.transfer_topo = TransferTopology( tp_rank=self.tp_rank, + tp_size=self.world_size, + block_size=self.block_size, engine_id=self.engine_id, - remote_tp_size=self._tp_size, # shared state - remote_block_size=self._block_size, # shared state is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backends=self.attn_backends, @@ -665,7 +660,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_mamba=self._has_mamba, ) self.compat_hash = compute_nixl_compatibility_hash( - self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks + self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks ) if self.use_host_buffer: @@ -717,7 +712,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if isinstance(layer_spec, UniformTypeKVCacheSpecs): # MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs layer_spec = layer_spec.kv_cache_specs[layer_name] - cache_list = self.kv_topo.get_transfer_cache_regions( + cache_list = self.transfer_topo.get_transfer_cache_regions( cache_or_caches, layer_spec ) # `layer_spec.page_size_bytes` only accounts for logical page_size, that is @@ -730,7 +725,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) # For when registering multiple tensors eg K/V in separate regions. 
physical_page_size = physical_page_size // len(cache_list)
-            if self.kv_topo._cross_layers_blocks:
+            if self.transfer_topo.cross_layers_blocks:
                 # When cross-layers blocks are used, multiply by number of layers
                 physical_page_size = physical_page_size * len(
                     self.kv_cache_config.kv_cache_tensors
@@ -794,7 +789,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
         self.num_regions = len(caches_data)
 
-        if self.kv_topo.is_kv_layout_blocks_first:
+        if self.transfer_topo.is_kv_layout_blocks_first:
             # NOTE (NickLucche) When FlashInfer is used, memory is registered
             # with joint KV for each block. This minimizes the overhead in
             # registerMem allowing faster descs queries. In order to be able to
@@ -818,7 +813,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self.dst_num_blocks[self.engine_id] = self.num_blocks
 
         if self._has_mamba:
-            self._mamba_phys_ratio[self.engine_id] = (
+            self._physical_blocks_per_logical[self.engine_id] = (
                 self._physical_blocks_per_logical_kv_block
             )
             logger.info(
@@ -877,11 +872,13 @@ def _build_mamba_local(
         conv_offsets = self._conv_decomp.local_conv_offsets
         conv_size, ssm_size = self._mamba_ssm_size
         num_blocks = self._logical_num_blocks * block_size_ratio
-        phys_ratio = self._physical_blocks_per_logical_kv_block
+        physical_per_logical = self._physical_blocks_per_logical_kv_block
 
         result: list[tuple[int, int, int]] = []
         for i, base_addr in enumerate(base_addresses):
-            page_stride = self.block_len_per_layer[i] // block_size_ratio * phys_ratio
+            page_stride = (
+                self.block_len_per_layer[i] // block_size_ratio * physical_per_logical
+            )
             for off, sz in conv_offsets:
                 for blk in range(num_blocks):
                     result.append(
@@ -901,14 +898,14 @@ def _build_mamba_local(
     def _build_fa_remote_for_mamba(
         self,
         nixl_agent_meta: NixlAgentMetadata,
-        transfer_cfg: HeteroTPTransferConfig,
         block_size_ratio: int,
-        kv_topo: TpKVTopology,
+        transfer_topo: TransferTopology,
+        remote_engine_id: EngineId,
     ) -> list[tuple[int, int, int]]:
         """Build remote FA descriptors for mamba models.
 
-        Uses transfer_cfg for GQA-aware FA divisor and head-based rank offset
-        instead of the standard uniform tp_ratio split.
+        Uses TransferTopology for GQA-aware FA divisor and head-based rank
+        offset instead of the standard uniform tp_ratio split.
         """
         assert block_size_ratio == 1, (
             "Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
         )
         # TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA
         # hetero-TP logic stabilizes.
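Worked numbers (assumed purely for illustration) for the FA descriptor sizing this method performs when `tp_ratio < 0`: each local FA entry is split across the remote ranks carrying unique heads, not across all `|tp_ratio|` ranks:

```python
local_block_len = 8192       # bytes per local K (or V) block entry
remote_num_fa_reads = 2      # remote ranks with unique FA heads for this rank
descriptor_len = local_block_len // remote_num_fa_reads
# The resulting 4096-byte descriptor must equal the remote K-half so that
# NIXL source and destination descriptor lists line up.
assert descriptor_len == 4096
```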
- tp_ratio = transfer_cfg.tp_ratio + mamba_info = transfer_topo.get_engine_info(remote_engine_id) + assert isinstance(mamba_info, MambaEngineTransferInfo) + tp_ratio = transfer_topo.tp_ratio(mamba_info.remote_tp_size) result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): local_block_len = self.get_backend_aware_kv_block_len( @@ -927,9 +926,11 @@ def _build_fa_remote_for_mamba( local_block_len = remote_kv_block_len if tp_ratio < 0 and not self.use_mla: - local_block_len = local_block_len // transfer_cfg.physical_fa_num_reads + local_block_len = local_block_len // mamba_info.remote_num_fa_reads - rank_offset = transfer_cfg.fa_rank_offset(remote_kv_block_len) + rank_offset = transfer_topo.fa_rank_offset( + remote_engine_id, remote_kv_block_len + ) num_blocks = nixl_agent_meta.num_blocks page_size = nixl_agent_meta.block_lens[i] @@ -938,12 +939,12 @@ def _build_fa_remote_for_mamba( addr = base_addr + block_offset + rank_offset result.append((addr, local_block_len, nixl_agent_meta.device_id)) - if kv_topo.is_kv_layout_blocks_first: + if transfer_topo.is_kv_layout_blocks_first: second_split = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=False, mamba_view=False ) if tp_ratio < 0 and not self.use_mla: - second_split = second_split // transfer_cfg.physical_fa_num_reads + second_split = second_split // mamba_info.remote_num_fa_reads for block_id in range(num_blocks): block_offset = block_id * page_size addr = base_addr + block_offset + rank_offset @@ -982,15 +983,17 @@ def _build_mamba_remote( conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)] ssm_read_size = nixl_agent_meta.ssm_sizes[1] - remote_ratio = self._mamba_phys_ratio[nixl_agent_meta.engine_id] - num_blocks = nixl_agent_meta.num_blocks // remote_ratio + remote_physical_per_logical = self._physical_blocks_per_logical[ + nixl_agent_meta.engine_id + ] + num_blocks = nixl_agent_meta.num_blocks // remote_physical_per_logical device_id = nixl_agent_meta.device_id result: list[tuple[int, int, int]] = [] # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case # block lengths vary across layers (e.g. MLA). for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - page_stride = nixl_agent_meta.block_lens[i] * remote_ratio + page_stride = nixl_agent_meta.block_lens[i] * remote_physical_per_logical for off, sz in conv_offsets: for blk in range(num_blocks): result.append((base_addr + blk * page_stride + off, sz, device_id)) @@ -1020,8 +1023,8 @@ def register_local_xfer_handler( register another local_xfer_handler using remote block len to ensure data copy correctness. 
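A self-contained sketch, with made-up sizes, of the descriptor enumeration `_build_mamba_remote` performs above: remote kernel blocks are regrouped into logical blocks, and each conv-state slice is read at a stride of one full logical block. Names here are illustrative, not the vLLM code:

def build_remote_descs(base_addr, remote_num_blocks, remote_block_len,
                       remote_physical_per_logical, conv_offsets, device_id):
    # Remote kernel blocks regrouped into logical blocks.
    num_logical = remote_num_blocks // remote_physical_per_logical
    page_stride = remote_block_len * remote_physical_per_logical
    descs = []
    for off, size in conv_offsets:
        for blk in range(num_logical):
            descs.append((base_addr + blk * page_stride + off, size, device_id))
    return descs

descs = build_remote_descs(0x1000, 4, 256, 2, [(0, 64), (64, 64)], 0)
# 2 logical blocks x 2 conv slices -> 4 descriptors at a 512-byte stride.
assert len(descs) == 4 and descs[1][0] - descs[0][0] == 512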
""" - assert self.kv_topo is not None - kv_topo = self.kv_topo + assert self.transfer_topo is not None + transfer_topo = self.transfer_topo block_size_ratio = self.block_size // block_size blocks_data: list[tuple[int, int, int]] = [] @@ -1052,7 +1055,7 @@ def register_blocks(blocks_data: list[tuple[int, int, int]], mamba: bool): # (addr, len, device id) blocks_data.append((addr, kv_block_len, self.device_id)) - if kv_topo.is_kv_layout_blocks_first: + if transfer_topo.is_kv_layout_blocks_first: second_split = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=False, mamba_view=mamba ) @@ -1154,11 +1157,29 @@ def add_remote_agent( ) return self._remote_agents[engine_id][remote_tp_rank] - ### Register remote agent metadata - if engine_id not in self._tp_size: - self._tp_size[engine_id] = remote_tp_size - if engine_id not in self._block_size: - self._block_size[engine_id] = nixl_agent_meta.block_size + ### Register remote engine in TransferTopology (idempotent). + assert self.transfer_topo is not None + transfer_topo = self.transfer_topo + physical_blocks_per_logical = ( + compute_physical_blocks_per_logical( + nixl_agent_meta.ssm_sizes, + nixl_agent_meta.block_lens[0], + ) + if self._has_mamba + else 1 + ) + transfer_topo.register_remote_engine( + remote_engine_id=engine_id, + remote_tp_size=remote_tp_size, + remote_block_size=nixl_agent_meta.block_size, + remote_block_len=nixl_agent_meta.block_lens[0], + remote_physical_blocks_per_logical=physical_blocks_per_logical, + local_block_len=self.block_len_per_layer[0], + ) + if self._has_mamba and engine_id not in self._physical_blocks_per_logical: + self._physical_blocks_per_logical[engine_id] = physical_blocks_per_logical + + logger.info("Transfer plan: %s", transfer_topo.describe(engine_id)) remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata @@ -1171,16 +1192,10 @@ def add_remote_agent( # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| - assert self.kv_topo is not None - kv_topo = self.kv_topo - block_size_ratio = kv_topo.block_size_ratio_from_engine_id(engine_id) + block_size_ratio = transfer_topo.block_size_ratio(nixl_agent_meta.block_size) if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks - if self._has_mamba: - self._mamba_phys_ratio[engine_id] = compute_mamba_phys_ratio( - nixl_agent_meta.ssm_sizes, nixl_agent_meta.block_lens[0] - ) # Keep track of remote agent kv caches base addresses. self.kv_caches_base_addr[engine_id][remote_tp_rank] = ( @@ -1190,28 +1205,13 @@ def add_remote_agent( # This is 1 when P and D `--tensor-parallel-size` match. Otherwise, # this is the ratio between the two sizes. - tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id) + tp_ratio = transfer_topo.tp_ratio(remote_tp_size) # Handle tp_size>num_kv_heads: replicate KV cache. indexes_into_remote = ( - not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0 + not transfer_topo.replicates_kv_cache(engine_id) and tp_ratio > 0 ) - # Create transfer config (single source of truth for descriptor sizes). 
- if self._has_mamba and engine_id not in self._transfer_configs: - self._transfer_configs[engine_id] = HeteroTPTransferConfig( - tp_ratio=tp_ratio, - K=kv_topo.total_num_kv_heads, - d_tp=self.world_size, - p_tp=remote_tp_size, - d_rank=self.tp_rank, - use_mla=self.use_mla, - d_block_len=self.block_len_per_layer[0], - p_block_len=nixl_agent_meta.block_lens[0], - is_blocks_first=kv_topo.is_kv_layout_blocks_first, - ) - logger.info("Created %s", self._transfer_configs[engine_id].describe()) - logger.debug( "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s", engine_id, @@ -1232,12 +1232,10 @@ def add_remote_agent( self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] if self._has_mamba: - transfer_cfg = self._transfer_configs.get(engine_id) - assert transfer_cfg is not None - if transfer_cfg.needs_split_handles: + if transfer_topo.needs_split_handles(engine_id): # Mamba-HMA: FA and Mamba use different split factors. - for handle_data in transfer_cfg.compute_split_handle_data( - self.src_blocks_data, self.num_descs, abs_tp + for handle_data in transfer_topo.compute_split_handle_data( + engine_id, self.src_blocks_data, self.num_descs, abs_tp ): descs = self.nixl_wrapper.get_xfer_descs( handle_data, self.nixl_memory_type @@ -1248,12 +1246,8 @@ def add_remote_agent( self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) logger.info( - "Mamba-HMA split handles: targets=%s, fa_reads=%s, " - "fa_entry=%s, mamba_reads=%s, num_descs=%s", - transfer_cfg.transfer_targets, - transfer_cfg.physical_fa_num_reads, - transfer_cfg.fa_entry_size, - transfer_cfg.mamba_num_reads, + "Mamba-HMA split handles: %s, num_descs=%s", + transfer_topo.describe(engine_id), self.num_descs, ) else: @@ -1322,7 +1316,7 @@ def register_remote_blocks( (addr, local_block_len, nixl_agent_meta.device_id) ) - if kv_topo.is_kv_layout_blocks_first: + if transfer_topo.is_kv_layout_blocks_first: # With FlashInfer index V separately to allow head splitting. second_split = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=False, mamba_view=mamba @@ -1361,14 +1355,12 @@ def register_remote_blocks( engine_id, remote_tp_rank, ) - transfer_cfg = self._transfer_configs.get(engine_id) - assert transfer_cfg is not None blocks_data.extend( self._build_fa_remote_for_mamba( nixl_agent_meta, - transfer_cfg, block_size_ratio, - kv_topo, + transfer_topo, + engine_id, ) ) blocks_data.extend( @@ -1404,18 +1396,19 @@ def _validate_remote_agent_handshake( """ remote_engine_id = nixl_agent_meta.engine_id - assert self._tp_size[remote_engine_id] == remote_tp_size - assert self.kv_topo is not None + assert self.transfer_topo is not None + remote_info = self.transfer_topo.get_engine_info(remote_engine_id) + assert remote_info.remote_tp_size == remote_tp_size - tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) - block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( - remote_engine_id + tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size) + block_size_ratio = self.transfer_topo.block_size_ratio( + nixl_agent_meta.block_size ) # num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba. # Mamba models can have replicated FA KV with tp_ratio < 0. 
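The replication invariant checked just below can be restated standalone; this checker is illustrative only and restates just the one rule visible here:

def check_replication_invariant(tp_ratio: int, kv_replicated: bool,
                                has_mamba: bool) -> None:
    # Non-mamba models cannot combine a replicated KV cache (tp_size >
    # num_kv_heads) with P TP > D TP (tp_ratio < 0): the cache would be
    # both duplicated and sliced across more producer ranks.
    if not has_mamba:
        assert not (tp_ratio < 0 and kv_replicated)

check_replication_invariant(tp_ratio=-2, kv_replicated=True, has_mamba=True)
check_replication_invariant(tp_ratio=2, kv_replicated=True, has_mamba=False)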
if not self._has_mamba: assert not ( - tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id) + tp_ratio < 0 and self.transfer_topo.is_kv_replicated(remote_engine_id) ) if self._is_hma_required: @@ -1468,7 +1461,7 @@ def _validate_remote_agent_handshake( if ( abs(tp_ratio) != 1 and not self.use_mla - and not self.kv_topo.is_kv_replicated(remote_engine_id) + and not self.transfer_topo.is_kv_replicated(remote_engine_id) and kv_cache_layout != "HND" and not self.enable_permute_local_kv ): @@ -1479,7 +1472,7 @@ def _validate_remote_agent_handshake( # Block len can only vary across layers when using MLA. remote_block_len = nixl_agent_meta.block_lens[0] - if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id): + if self.use_mla or self.transfer_topo.is_kv_replicated(remote_engine_id): # With replicated KV cache, only the number of blocks can differ. # TODO (ZhanqiuHu): For mamba models, validate FA and mamba # block_lens separately. @@ -1595,7 +1588,7 @@ def post_process_device_kv_on_receive( if len(self.device_kv_caches) == 0: return assert block_size_ratio >= 1, "Only nP < nD supported currently." - assert self.kv_topo is not None + assert self.transfer_topo is not None if self.enable_permute_local_kv and block_size_ratio > 1: logger.debug( "Post-processing device kv cache on receive by converting " @@ -1615,7 +1608,7 @@ def post_process_device_kv_on_receive( block_size_ratio, ) - split_k_and_v = self.kv_topo.split_k_and_v + split_k_and_v = self.transfer_topo.split_k_and_v for block_ids in block_ids_list: indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long) @@ -1662,7 +1655,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: The scheduler process (via the MultiprocExecutor) will use this output to track which workers are done. """ - assert self.kv_topo is not None + assert self.transfer_topo is not None done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) @@ -1690,8 +1683,9 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.sync_recved_kv_to_device(req_id, meta) # post processing for heteroblocksize - block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( - meta.remote.engine_id + remote_info = self.transfer_topo.get_engine_info(meta.remote.engine_id) + block_size_ratio = self.transfer_topo.block_size_ratio( + remote_info.remote_block_size ) if not self.use_mla and ( block_size_ratio > 1 or self.enable_permute_local_kv @@ -1742,7 +1736,7 @@ def _get_new_notifs(self) -> set[str]: are reading from the same producer (heterogeneous TP scenario), wait for all consumers to be done pulling. """ - assert self.kv_topo is not None + assert self.transfer_topo is not None notified_req_ids: set[str] = set() for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: @@ -1761,7 +1755,7 @@ def _get_new_notifs(self) -> set[str]: # NOTE: `tp_ratio` is the opposite when swapping local<>remote n_consumers = int(tp_size) - tp_ratio = self.kv_topo.tp_ratio(n_consumers) + tp_ratio = self.transfer_topo.tp_ratio(n_consumers) # Number of reads *per producer* to wait for. # When remote D TP > local P TP we expect `tp_ratio` reads. 
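A small self-contained sketch, under assumed semantics, of the producer-side counting `_get_new_notifs` implements above: with heterogeneous TP several D workers may read from one P worker, so a request is only reported done after every expected consumer read arrives. Names are hypothetical:

from collections import defaultdict

def expected_reads(n_consumers: int, producer_tp_size: int) -> int:
    # Remote D TP > local P TP: `tp_ratio` D ranks read from this P rank.
    if n_consumers > producer_tp_size:
        return n_consumers // producer_tp_size
    return 1  # D TP <= P TP: one read per request from this producer.

counts: defaultdict[str, int] = defaultdict(int)

def on_read_notification(req_id: str, n_consumers: int,
                         producer_tp_size: int, done: set[str]) -> None:
    counts[req_id] += 1
    if counts[req_id] == expected_reads(n_consumers, producer_tp_size):
        del counts[req_id]
        done.add(req_id)

done: set[str] = set()
for _ in range(2):  # P TP 2, D TP 4 -> wait for 2 reads
    on_read_notification("req-0", n_consumers=4, producer_tp_size=2, done=done)
assert done == {"req-0"}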
@@ -1902,17 +1896,17 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): - assert meta.remote is not None and self.kv_topo is not None - remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( - meta.remote.engine_id - ) - tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id) + assert meta.remote is not None and self.transfer_topo is not None + engine_id = meta.remote.engine_id + remote_ranks = self.transfer_topo.target_remote_ranks(engine_id) + remote_info = self.transfer_topo.get_engine_info(engine_id) + tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size) if self._has_mamba: # Expand remote logical → kernel block IDs. meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( meta.remote.block_ids, - self._mamba_phys_ratio[meta.remote.engine_id], + self._physical_blocks_per_logical[meta.remote.engine_id], ) else: meta.remote.block_ids = self._logical_to_kernel_block_ids( @@ -1925,7 +1919,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): # the first remote rank (cache is duplicated).. break - remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id] + remote_block_size = remote_info.remote_block_size logger.debug( "Remote agent %s available, calling _read_blocks" " on remote rank %s with remote block size %s for req %s", @@ -1956,9 +1950,8 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_ids: BlockIds = meta.remote.block_ids if self._has_mamba: # Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets. - transfer_cfg = self._transfer_configs.get(meta.remote.engine_id) - assert transfer_cfg is not None - local_ids, remote_ids = transfer_cfg.filter_block_ids_for_rank( + local_ids, remote_ids = self.transfer_topo.filter_block_ids_for_rank( + engine_id, remote_rank, local_ids, remote_ids, @@ -2000,8 +1993,11 @@ def _read_blocks( Post a READ point-to-point xfer request from a single local worker to a single remote worker. """ - assert self.kv_topo is not None - block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) + assert self.transfer_topo is not None + remote_info = self.transfer_topo.get_engine_info(dst_engine_id) + block_size_ratio = self.transfer_topo.block_size_ratio( + remote_info.remote_block_size + ) if block_size_ratio > 1: # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups. assert not self._is_hma_required @@ -2191,8 +2187,8 @@ def _get_block_descs_ids( # This is like having two "low-level views" of the same storage. # `num_fa_descs` offset must be computed per-engine since P and D can # have different num_blocks (and thus different FA descs counts). - ratio = self._mamba_phys_ratio[engine_id] - logical_blocks = num_blocks // ratio + physical_per_logical = self._physical_blocks_per_logical[engine_id] + logical_blocks = num_blocks // physical_per_logical num_fa_descs = self.num_regions * num_blocks # 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm). mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None] @@ -2235,21 +2231,22 @@ def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: ] def _logical_to_remote_kernel_block_ids( - self, block_ids: BlockIds, remote_ratio: int + self, block_ids: BlockIds, remote_physical_per_logical: int ) -> BlockIds: """Map logical block IDs to physical kernel block IDs on the remote. 
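A worked example of the FA block-ID expansion this method performs, mirroring the numpy code in the hunk below: each remote logical block L maps to kernel blocks [L * remote_ratio, L * remote_ratio + local_ratio - 1].

import numpy as np

remote_physical_per_logical = 2
local_ratio = 2
logical_ids = [0, 3]

arr = np.array(logical_ids).reshape(-1, 1)
local_arange = np.arange(local_ratio).reshape(1, -1)
expanded = (arr * remote_physical_per_logical + local_arange).flatten()
assert expanded.tolist() == [0, 1, 6, 7]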
Args: block_ids: per-group lists of logical block IDs. - remote_ratio: remote engine's physical blocks per logical block. + remote_physical_per_logical: remote engine's physical blocks + per logical block. Returns: Same structure with FA groups expanded (each logical block L - becomes kernel blocks [L*remote_ratio .. L*remote_ratio + - local_ratio - 1]). Mamba groups are passed through unchanged. + becomes kernel blocks [L*ratio .. L*ratio + local_ratio - 1]). + Mamba groups are passed through unchanged. """ local_ratio = self._physical_blocks_per_logical_kv_block - if remote_ratio == 1: + if remote_physical_per_logical == 1: return block_ids local_arange = np.arange(local_ratio).reshape(1, -1) group_specs = self.kv_cache_config.kv_cache_groups @@ -2257,7 +2254,7 @@ def _logical_to_remote_kernel_block_ids( for i, group in enumerate(block_ids): if not isinstance(group_specs[i].kv_cache_spec, MambaSpec): arr = np.array(group).reshape(-1, 1) - expanded = (arr * remote_ratio + local_arange).flatten() + expanded = (arr * remote_physical_per_logical + local_arange).flatten() result.append(expanded.tolist()) else: # Mamba blocks are 1:1 logical-to-physical (no expansion). @@ -2297,8 +2294,8 @@ def get_backend_aware_kv_block_len( +-------------------+ +--------------------+ |1st_split-2nd_split| |1st_split-2nd_split | """ - assert self.kv_topo is not None - if self.kv_topo.is_kv_layout_blocks_first: + assert self.transfer_topo is not None + if self.transfer_topo.is_kv_layout_blocks_first: # For indexing only half (either just the K or V part). if mamba_view: # NOTE (NickLucche) Mamba Opt: this is already skipping the padding so diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py index c8a5e10344bd..309426814c68 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py @@ -151,7 +151,9 @@ def derive_mamba_conv_split( ) -def compute_mamba_phys_ratio(ssm_sizes: tuple[int, ...], block_len: int) -> int: +def compute_physical_blocks_per_logical( + ssm_sizes: tuple[int, ...], block_len: int +) -> int: """Derive _physical_blocks_per_logical_kv_block from remote metadata. The remote engine's ratio is not sent directly in the handshake, so we