diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index a2a46684bb7a..c5784d1c2002 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1063,6 +1063,87 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( # whole block is moved. worker.add_remote_agent(meta, remote_tp_size=1) + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper", + FakeNixlWrapper, + ) + def test_handshake_mixed_fa_mla_hetero_tp(self, default_vllm_config, dist_init): + """Mixed full-attn (SPLIT) + MLA (REPLICATE) single KV group under + heterogeneous TP must NOT raise (previously a NotImplementedError), + and the per-region gate must still reject a wrong block_len. + """ + vllm_config = create_vllm_config() + with patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.get_tensor_model_parallel_world_size", # noqa: E501 + return_value=2, + ): + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + + # Region 0: full-attn (SPLIT). Region 1: MLA (REPLICATE). + fa_len = 4096 * worker.block_size + idx_len = 512 * worker.block_size + worker.slot_size_per_layer = [4096, 512] + worker.block_len_per_layer = [fa_len, idx_len] + worker._region_is_mla = [False, True] + worker.num_blocks = 1 + worker.dst_num_blocks[worker.engine_id] = worker.num_blocks + worker.src_blocks_data = [ + (0, fa_len, worker.tp_rank), + (0, idx_len, worker.tp_rank), + ] + worker.num_descs = len(worker.src_blocks_data) + + # D_TP=2, P_TP=1 -> tp_ratio=2. SPLIT region scales by tp_ratio; + # REPLICATE region is unchanged. + tp_ratio = 2 + meta = NixlAgentMetadata( + engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0, 0], + device_id=0, + num_blocks=1, + block_lens=[fa_len * tp_ratio, idx_len], + kv_cache_layout=worker.kv_cache_layout, + block_size=worker.block_size, + ssm_sizes=(0, 0), + attn_backend_name=worker.backend_name, + physical_blocks_per_logical_kv_block=1, + ) + worker.add_remote_agent(meta, remote_tp_size=1) + assert ( + FakeNixlConnectorWorker.REMOTE_ENGINE_ID in worker.dst_xfer_side_handles + ) + # Gate rejects an MLA region wrongly scaled by tp_ratio. + worker2 = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker2.block_len_per_layer = [fa_len, idx_len] + worker2._region_is_mla = [False, True] + worker2.num_blocks = 1 + worker2.dst_num_blocks[worker2.engine_id] = worker2.num_blocks + bad_meta = NixlAgentMetadata( + engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0, 0], + device_id=0, + num_blocks=1, + # WRONG: MLA region scaled by tp_ratio (it should be replicated). + block_lens=[fa_len * tp_ratio, idx_len * tp_ratio], + kv_cache_layout=worker2.kv_cache_layout, + block_size=worker2.block_size, + ssm_sizes=(0, 0), + attn_backend_name=worker2.backend_name, + physical_blocks_per_logical_kv_block=1, + ) + with pytest.raises(AssertionError): + worker2.add_remote_agent(bad_meta, remote_tp_size=1) + # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which # we put here is important. First run ray, it will clean up the resources, then diff --git a/tests/v1/kv_connector/unit/test_tp_mapping.py b/tests/v1/kv_connector/unit/test_tp_mapping.py index 95d49faf042f..5ab6b68400c9 100644 --- a/tests/v1/kv_connector/unit/test_tp_mapping.py +++ b/tests/v1/kv_connector/unit/test_tp_mapping.py @@ -73,9 +73,19 @@ def test_source_ranks_p_gt_d(self): def _make_mock_worker_for_splits(group_spec_types): - """Build a mock NixlConnectorWorker with _group_spec_types for split tests.""" + """Build a mock NixlConnectorWorker with _group_spec_types for split tests. + + No per-region replicate flags are configured (``block_len_per_layer`` empty + and ``num_regions == 0``), so ``_fa_desc_replicated`` takes its early-return + path and treats every FA descriptor as SPLIT, matching the legacy behavior + these tests assert. + """ worker = object.__new__(NixlConnectorWorker) worker._group_spec_types = group_spec_types + worker.transfer_topo = SimpleNamespace(virtually_split_kv_in_blocks=False) + worker.block_len_per_layer = [] + worker.num_regions = 0 + worker._region_is_mla = [] return worker diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index e4b20c01f4de..213a3b031446 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -71,6 +71,7 @@ from vllm.v1.kv_cache_interface import ( FullAttentionSpec, MambaSpec, + MLAAttentionSpec, UniformTypeKVCacheSpecs, ) from vllm.v1.worker.block_table import BlockTable @@ -178,19 +179,63 @@ def _build_local_splits_from_plan( else 0 ) + # Per-FA-descriptor replicate flag, in _build_fa_local emission order. + fa_desc_replicated = self._fa_desc_replicated(num_fa_descs) + for p_idx, p_rank in enumerate(plan.all_source_ranks): fa_slot = plan.rank_to_attention_slot.get(p_rank, 0) handle: list[tuple[int, int, int]] = [] for j, (addr, local_len, dev) in enumerate(src_blocks_data): if j < num_fa_descs: - chunk = local_len // fa_num_splits - handle.append((addr + fa_slot * chunk, chunk, dev)) + if fa_desc_replicated[j]: + # REPLICATE (MLA): whole block written on every rank. + handle.append((addr, local_len, dev)) + else: + # SPLIT (full-attn): this rank's head slice. + chunk = local_len // fa_num_splits + handle.append((addr + fa_slot * chunk, chunk, dev)) else: chunk = local_len // ssm_num_splits handle.append((addr + p_idx * chunk, chunk, dev)) yield handle + def _fa_desc_replicated(self, num_fa_descs: int) -> list[bool]: + """Per-FA-descriptor replicate flag, in _build_fa_local emission order + (region-major; K then optional V per region). Length ``num_fa_descs``. + """ + assert self.transfer_topo is not None + n_regions = len(self.block_len_per_layer) + # Unset only when the worker is built directly in unit tests; a real + # model always registers regions (no-KV-cache crashes long before here). + # Fall back to all-SPLIT to preserve the pre-per-region behavior. + if n_regions == 0 or self.num_regions == 0: + return [False] * num_fa_descs + # Descriptors (blocks) per stream; all streams share the same count. + nblk = num_fa_descs // self.num_regions + virtually_split = self.transfer_topo.virtually_split_kv_in_blocks + flags: list[bool] = [] + for i in range(n_regions): + replicated = self._is_region_replicated(i) + # REPLICATE (MLA) is key-only -> 1 stream; SPLIT emits K and V + # (2 streams) under the virtually-split layout. + num_streams = 1 if replicated or not virtually_split else 2 + flags.extend([replicated] * (num_streams * nblk)) + assert len(flags) == num_fa_descs, ( + f"FA desc flags {len(flags)} != num_fa_descs {num_fa_descs}" + ) + return flags + + def _is_region_replicated(self, region_idx: int) -> bool: + """Whether region ``region_idx`` is transferred REPLICATE vs SPLIT. + + REPLICATE (MLA): identical on every rank, whole block read from one + rank at offset 0, key-only. SPLIT (full-attn): head-sharded across TP. + Defaults to SPLIT when the per-region map is unset (e.g. tests that set + block_len_per_layer without register_kv_caches). + """ + return region_idx < len(self._region_is_mla) and self._region_is_mla[region_idx] + def __init__( self, vllm_config: "VllmConfig", @@ -450,6 +495,15 @@ def __init__( for g in self.kv_cache_config.kv_cache_groups ) + # Per-region MLA flag, 1:1 with block_len_per_layer. True -> REPLICATE + # (MLA), False -> SPLIT (head-sharded full-attn). Mixed only for models + # combining both (e.g. GQA main + MLA Eagle-3 draft). + self._region_is_mla = list[bool]() + + # Enable different block lengths for different layers *only* when MLA is used. + # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`. + self.block_len_per_layer = list[int]() + # Per-engine TP mappings. Generated during handshake. self.tp_mappings: dict[EngineId, TPMapping] = {} @@ -849,9 +903,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # to better exploit the memory layout (ie num_blocks is the first dim). tensor_size_bytes = None - # Enable different block lengths for different layers *only* when MLA is used. - # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`. - self.block_len_per_layer = list[int]() for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. @@ -895,8 +946,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # `page_size` accounts for physical blocks, st KVCache is always # [`num_blocks` * `page_size`] curr_tensor_size_bytes = num_blocks * physical_page_size - if tensor_size_bytes is None: - tensor_size_bytes = curr_tensor_size_bytes # TODO (NickLucche) we could eventually unify how we handle FA/FI regions, # registering a single tensor for both K/V and splitting logically like FI. @@ -920,6 +969,20 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) else: self.block_len_per_layer.append(physical_page_size) + is_mla_region = isinstance(layer_spec, MLAAttentionSpec) + self._region_is_mla.append(is_mla_region) + + # HeteroTP cannot transfer differently-sized regions, so every + # non-MLA region in a group must share one tensor size (this also + # holds for Mamba-like models). The sole exception is the DeepSeek + # MLA indexer, which sits in a UniformTypeKVCacheSpecs group at a + # different size; MLA regions are therefore exempt. + if not is_mla_region: + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + assert tensor_size_bytes == curr_tensor_size_bytes, ( + "All non-MLA kv cache tensors must have the same size" + ) if cache.shape[0] != num_blocks: raise AssertionError( @@ -937,12 +1000,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): f"{self.transfer_topo.is_kv_layout_blocks_first}" ) - if not self.use_mla: - # Different kv cache shape is not supported by HeteroTP. - # This must also hold true for Mamba-like models. - assert tensor_size_bytes == curr_tensor_size_bytes, ( - "All kv cache tensors must have the same size" - ) # Need to make sure the device ID is non-negative for NIXL, # Torch uses -1 to indicate CPU tensors. self.device_id = max(cache.get_device(), 0) @@ -953,7 +1010,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.debug( "Different block lengths collected: %s", set(self.block_len_per_layer) ) - assert len(self.block_len_per_layer) == len(seen_base_addresses) + assert ( + len(self.block_len_per_layer) + == len(seen_base_addresses) + == len(self._region_is_mla) + ) self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses self.num_regions = len(caches_data) @@ -967,7 +1028,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # of 'virtual' regions here and halve `block_len` below. # Similarly for Mamba layers, we register SSM+Conv as a single region and # then duplicate it logically to be able to index SSM/Conv separately. - self.num_regions *= 2 + # Exception: key-only REPLICATE regions (MLA) have no V half, so + # they contribute a single desc stream and are not doubled. + self.num_regions = sum( + 1 if self._is_region_replicated(i) else 2 + for i in range(len(self._region_is_mla)) + ) # Total local FA descriptors (boundary between FA and mamba descs). self.num_descs = self.num_regions * self.num_blocks @@ -1133,10 +1199,13 @@ def _build_fa_local( addr = base_addr + block_offset result.append((addr, kv_block_len, self.device_id)) - if self.transfer_topo.virtually_split_kv_in_blocks: + if ( + self.transfer_topo.virtually_split_kv_in_blocks + and not self._is_region_replicated(i) + ): # Separate and interleave K/V regions to maintain the same # descs ordering. This is needed for selecting contiguous heads - # when split across TP ranks. + # when split across TP ranks. (Skipped for key-only REPLICATE.) second_split = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=False, mamba_view=False ) @@ -1158,10 +1227,13 @@ def _build_fa_remote( fa_group_idx = next( i for i, t in enumerate(self._group_spec_types) if _is_attention_spec(t) ) - num_attn_reads = len(plan.source_ranks_per_group[fa_group_idx]) + # SPLIT regions read their head slice from this many remote ranks at a + # per-rank offset; REPLICATE regions read the whole block once. + split_reads = len(plan.source_ranks_per_group[fa_group_idx]) num_blocks = nixl_agent_meta.num_blocks result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + replicated = self._is_region_replicated(i) # Read our whole local region size from remote.. local_block_len = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False @@ -1171,8 +1243,13 @@ def _build_fa_remote( # ..using remote kv_block_len as transfer unit local_block_len = remote_kv_block_len - local_block_len = local_block_len // num_attn_reads - rank_offset = plan.rank_offset_factor * remote_kv_block_len + # REPLICATE reads the whole block once at offset 0; SPLIT gathers + # its head slice from `split_reads` remote ranks at a per-rank offset. + num_reads = 1 if replicated else split_reads + rank_offset = ( + 0 if replicated else plan.rank_offset_factor * remote_kv_block_len + ) + local_block_len = local_block_len // num_reads page_size = nixl_agent_meta.block_lens[i] for block_id in range(num_blocks): @@ -1182,12 +1259,13 @@ def _build_fa_remote( addr = base_addr + block_offset + rank_offset result.append((addr, local_block_len, nixl_agent_meta.device_id)) - if self.transfer_topo.virtually_split_kv_in_blocks: + emits_v = self.transfer_topo.virtually_split_kv_in_blocks and not replicated + if emits_v: # With FlashInfer index V separately to allow head splitting. second_split = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=False, mamba_view=False ) - second_split = second_split // num_attn_reads + second_split = second_split // num_reads for block_id in range(num_blocks): block_offset = block_id * page_size addr = base_addr + block_offset + rank_offset @@ -1527,49 +1605,43 @@ def _validate_remote_agent_handshake( "Use HND layout on the prefill side." ) - # Block len can only vary across layers when using MLA. - remote_block_len = nixl_agent_meta.block_lens[0] - if self.use_mla or self.transfer_topo.is_kv_replicated(remote_engine_id): - # With replicated KV cache, only the number of blocks can differ. - # TODO (ZhanqiuHu): For mamba models, validate FA and mamba - # block_lens separately. - if not self._has_mamba: - for i in range(len(self.block_len_per_layer)): - assert ( - self.block_len_per_layer[i] // block_size_ratio - == nixl_agent_meta.block_lens[i] - ), "KV cache sizes must match between P and D when replicated" - else: - # When MLA is not used, this is a list of the same block length - for block_len in nixl_agent_meta.block_lens: - assert block_len == remote_block_len, ( - "All remote layers must have the same block size" - ) - - # HMA hybrid models (mamba+attention) pad block_len to - # max(attn_page, mamba_page), so the linear tp_ratio scaling - # assumption only holds for pure-attention models. - if not self._has_mamba: - if tp_ratio > 0: - assert ( - remote_block_len - == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio - ), ( - "Remote P worker KV layer cache must be of shape [2, N," - " local_kv_heads*tp_ratio, page_size, head_dim] and " - "same dtype." + # Per-region block_len validation enforcing the P/D invariant. + # REPLICATE regions (MLA, or a whole-model MLA / replicated-KV transfer) + # only allow the number of blocks to differ; SPLIT regions scale with + # tp_ratio. Mamba uses the ssm_sizes counterpart, so skip block_len here. + if not self._has_mamba: + assert len(self.block_len_per_layer) == len(nixl_agent_meta.block_lens), ( + "Number of KV layers must match between prefill and decode" + ) + model_replicated = self.use_mla or self.transfer_topo.is_kv_replicated( + remote_engine_id + ) + for i, local_len in enumerate(self.block_len_per_layer): + replicated = model_replicated or self._is_region_replicated(i) + remote_len = nixl_agent_meta.block_lens[i] + if replicated: + # Whole block copied; only the number of blocks may differ. + assert local_len // block_size_ratio == remote_len, ( + "KV cache sizes must match between P and D when " + f"replicated (region {i}: local={local_len}, " + f"remote={remote_len}, bsr={block_size_ratio})." + ) + elif tp_ratio > 0: + # D_TP >= P_TP: remote holds tp_ratio x local heads. + assert remote_len == (local_len * tp_ratio) // block_size_ratio, ( + f"SPLIT region {i}: remote P KV block_len {remote_len} " + f"must equal local {local_len} * tp_ratio {tp_ratio} " + f"// block_size_ratio {block_size_ratio}." ) else: + # P_TP > D_TP: local holds |tp_ratio| x remote heads. assert block_size_ratio == 1, ( - "Different local/remote block sizes are not supported" - " when P TP > D TP." + "Different local/remote block sizes are not supported " + "when P TP > D TP." ) - assert remote_block_len == self.block_len_per_layer[0] // ( - -tp_ratio - ), ( - "Remote P worker KV layer cache must be of shape [2, N," - " local_kv_heads/tp_ratio, page_size, head_dim] and " - "same dtype." + assert remote_len == local_len // (-tp_ratio), ( + f"SPLIT region {i}: remote P KV block_len {remote_len} " + f"must equal local {local_len} // |tp_ratio| {-tp_ratio}." ) # TP workers that handhshake with same remote have same #blocks. @@ -2450,13 +2522,16 @@ def get_backend_aware_kv_block_len( |1st_split-2nd_split| |1st_split-2nd_split | """ assert self.transfer_topo is not None - if self.transfer_topo.virtually_split_kv_in_blocks: - if mamba_view: - block_len = self._mamba_ssm_size[not first_split] - else: - block_len = self.block_len_per_layer[layer_idx] // 2 + virtually_split = self.transfer_topo.virtually_split_kv_in_blocks + if virtually_split and mamba_view: + block_len = self._mamba_ssm_size[not first_split] else: - block_len = self.block_len_per_layer[layer_idx] + # Per-descriptor block length: a SPLIT region (full-attn under the + # virtually-split layout) emits separate K and V and uses + # block_len//2; REPLICATE (MLA, key-only) and non-split layouts use + # the whole block. + half_block = virtually_split and not self._is_region_replicated(layer_idx) + block_len = self.block_len_per_layer[layer_idx] // (2 if half_block else 1) return block_len def get_kv_connector_stats(self) -> KVConnectorStats | None: