From 75825f5ebdbaec224ef515d2fbf0fe5389d9970f Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Fri, 5 Jun 2026 19:00:07 +0800 Subject: [PATCH 01/36] try fix with token instead of bytes Signed-off-by: JaredforReal --- .../kv_transfer/kv_connector/utils.py | 22 ++++-- .../kv_connector/v1/nixl/worker.py | 67 +++++++++++++++---- 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 0ab694b7e73d..f21b6bb07e82 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -533,12 +533,24 @@ def tp_ratio(self, remote_tp_size: int) -> int: return -(remote_tp_size // self.tp_size) def block_size_ratio(self, remote_block_size: int) -> int: - """Calculate the block size ratio between local and remote.""" - assert self.block_size % remote_block_size == 0, ( - f"Local block size {self.block_size} is not divisible " - f"by remote block size {remote_block_size} or vice versa." + """Calculate the block size ratio between local and remote. + + Positive when local >= remote (local blocks are larger). + Negative when remote > local (remote blocks are larger). + """ + if self.block_size == remote_block_size: + return 1 + if self.block_size > remote_block_size: + assert self.block_size % remote_block_size == 0, ( + f"Local block size {self.block_size} is not divisible " + f"by remote block size {remote_block_size}." + ) + return self.block_size // remote_block_size + assert remote_block_size % self.block_size == 0, ( + f"Remote block size {remote_block_size} is not divisible " + f"by local block size {self.block_size}." 
) - return self.block_size // remote_block_size + return -(remote_block_size // self.block_size) def is_kv_replicated( self, remote_engine_id: EngineId, remote_pp_rank: int = 0 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index a297058c845e..d57e02abd456 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1151,9 +1151,13 @@ def _build_fa_remote( local_block_len = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - # ..using remote kv_block_len as transfer unit + # Use actual remote block length from metadata for correct + # byte-size computation. Deriving from local / block_size_ratio + # is wrong for heterogeneous TP where local and remote have + # different byte-per-block values (e.g. hybrid MLA+GDN models). + remote_kv_block_len = nixl_agent_meta.block_lens[i] + if local_block_len > remote_kv_block_len: + # Remote has smaller blocks, use remote's size as transfer unit local_block_len = remote_kv_block_len local_block_len = local_block_len // num_attn_reads @@ -1305,6 +1309,7 @@ def add_remote_agent( remote_tp_size=remote_tp_size, group_spec_types=self._group_spec_types, ) + self._raise_if_hma_remote_block_size_mismatch(nixl_agent_meta.block_size) remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata @@ -1317,7 +1322,32 @@ def add_remote_agent( # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| - block_size_ratio = transfer_topo.block_size_ratio(nixl_agent_meta.block_size) + # Compute block_size_ratio from actual byte-per-block values so that + # heterogeneous TP works even when block_size carries byte values + # (e.g. 
hybrid MLA+GDN models where block_size differs across TP + # configs). _build_fa_remote already uses nixl_agent_meta.block_lens + # directly, so this ratio is only needed for handler registration and + # descriptor-ID computation. + if ( + self.block_len_per_layer + and nixl_agent_meta.block_lens + and self.block_len_per_layer[0] != nixl_agent_meta.block_lens[0] + ): + local_bytes = self.block_len_per_layer[0] + remote_bytes = nixl_agent_meta.block_lens[0] + if local_bytes > remote_bytes and local_bytes % remote_bytes == 0: + block_size_ratio = local_bytes // remote_bytes + elif remote_bytes > local_bytes and remote_bytes % local_bytes == 0: + block_size_ratio = -(remote_bytes // local_bytes) + else: + # Non-exact byte division (e.g. hybrid models with + # TP-independent MLA component). Use 1 as fallback; + # _build_fa_remote handles bytes via remote block_lens. + block_size_ratio = 1 + else: + block_size_ratio = transfer_topo.block_size_ratio( + nixl_agent_meta.block_size + ) if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks @@ -1425,9 +1455,14 @@ def _validate_remote_agent_handshake( assert remote_info.remote_tp_size == remote_tp_size tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size) - block_size_ratio = self.transfer_topo.block_size_ratio( - nixl_agent_meta.block_size - ) + try: + block_size_ratio = self.transfer_topo.block_size_ratio( + nixl_agent_meta.block_size + ) + except AssertionError: + # Heterogeneous TP with non-divisible block sizes (e.g. hybrid + # MLA+GDN). Use 1 as a safe fallback for validation checks. + block_size_ratio = 1 # num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba. # Mamba models can have replicated FA KV with tp_ratio < 0. # MLA models do not need to handle kv replication. 
@@ -1744,9 +1779,12 @@ def get_finished(self) -> tuple[set[str], set[str]]: # post processing for heteroblocksize remote_info = self.transfer_topo.get_engine_info(meta.remote.engine_id) - block_size_ratio = self.transfer_topo.block_size_ratio( - remote_info.remote_block_size - ) + try: + block_size_ratio = self.transfer_topo.block_size_ratio( + remote_info.remote_block_size + ) + except AssertionError: + block_size_ratio = 1 if not self.use_mla and ( block_size_ratio > 1 or self.enable_permute_local_kv ): @@ -2122,9 +2160,12 @@ def _read_blocks( remote_block_ids = read_spec.remote_block_ids remote_info = self.transfer_topo.get_engine_info(dst_engine_id) - block_size_ratio = self.transfer_topo.block_size_ratio( - remote_info.remote_block_size - ) + try: + block_size_ratio = self.transfer_topo.block_size_ratio( + remote_info.remote_block_size + ) + except AssertionError: + block_size_ratio = 1 if block_size_ratio > 1: # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups. 
 assert not self._is_hma_required From de8591a7cccd1a3ade93adc9cb35d57d15a77194 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Sat, 6 Jun 2026 11:08:29 +0800 Subject: [PATCH 02/36] delete undefined variable Signed-off-by: JaredforReal --- vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index d57e02abd456..a1826f6644ce 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1309,7 +1309,6 @@ def add_remote_agent( remote_tp_size=remote_tp_size, group_spec_types=self._group_spec_types, ) - self._raise_if_hma_remote_block_size_mismatch(nixl_agent_meta.block_size) remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata From 256c1fd61b8b092b5a82ab59b79f962da431216d Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Sat, 6 Jun 2026 11:52:34 +0800 Subject: [PATCH 03/36] fix nixl handshake Signed-off-by: JaredforReal --- .../kv_transfer/kv_connector/v1/nixl/worker.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index a1826f6644ce..5740ce5bd4d8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1151,13 +1151,9 @@ def _build_fa_remote( local_block_len = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False ) - # Use actual remote block length from metadata for correct - # byte-size computation. Deriving from local / block_size_ratio - # is wrong for heterogeneous TP where local and remote have - # different byte-per-block values (e.g. hybrid MLA+GDN models). 
- remote_kv_block_len = nixl_agent_meta.block_lens[i] - if local_block_len > remote_kv_block_len: - # Remote has smaller blocks, use remote's size as transfer unit + remote_kv_block_len = local_block_len // block_size_ratio + if block_size_ratio > 1: + # ..using remote kv_block_len as transfer unit local_block_len = remote_kv_block_len local_block_len = local_block_len // num_attn_reads From 8e048f57d57a4893c2f8e6f671f84c302b9fee5c Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 14:33:01 +0800 Subject: [PATCH 04/36] split MLA SSM regions Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 54 ++++++++++++++++--- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 5740ce5bd4d8..d9b0e239e682 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -95,7 +95,7 @@ def _compute_desc_ids( ) -> np.ndarray: """Compute NIXL descriptor IDs for given block IDs.""" num_fa_regions = self.num_regions - num_ssm_regions = len(self.block_len_per_layer) * 4 if self._has_mamba else 0 + num_ssm_regions = sum(self._is_ssm_region) * 4 if self._has_mamba else 0 num_blocks = dst_num_blocks if block_size_ratio is not None: @@ -365,6 +365,10 @@ def __init__( # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) self.num_regions = 0 + # Per-region flag: True if the region is an SSM/Mamba layer. + # Populated during register_kv_caches; used to route FA descriptors + # to attention regions and Mamba descriptors to SSM regions. + self._is_ssm_region: list[bool] = [] # nixl_prepped_dlist_handle. 
self.src_xfer_handles_by_block_size: dict[int, int] = {} @@ -844,6 +848,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Enable different block lengths for different layers *only* when MLA is used. # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`. self.block_len_per_layer = list[int]() + self._is_ssm_region = list[bool]() for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. @@ -898,8 +903,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): "Registering layer %s with cache shape: %s", layer_name, cache.shape ) seen_base_addresses.append(base_addr) + is_ssm = isinstance(layer_spec, MambaSpec) + self._is_ssm_region.append(is_ssm) # Only record non-Mamba page sizes. - if isinstance(layer_spec, MambaSpec): + if is_ssm: self.block_len_per_layer.append( physical_page_size // self._physical_blocks_per_logical_kv_block ) @@ -941,7 +948,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert len(self.block_len_per_layer) == len(seen_base_addresses) self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses - self.num_regions = len(caches_data) + # FA regions count only attention layers. SSM/Mamba regions are + # served by Mamba descriptors, not FA descriptors, and must be + # excluded here so that FA descriptor IDs do not reference them. + # (For hybrid MLA+GDN models, building FA descriptors for the GDN + # region with virtual K/V split would produce out-of-bounds addresses + # under heterogeneous TP, failing NIXL prepXferDlist.) 
+ num_attention_base = len(caches_data) - sum(self._is_ssm_region) + self.num_regions = num_attention_base if self.transfer_topo.virtually_split_kv_in_blocks: # NOTE (NickLucche) When FlashInfer is used, memory is registered @@ -1030,6 +1044,12 @@ def _build_mamba_local( result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(base_addresses): + # Only build Mamba descriptors for SSM/Mamba regions. + # Attention/MLA regions do not contain conv or temporal state; + # building Mamba descriptors for them would reference addresses + # outside their registered memory, causing prepXferDlist failures. + if not self._is_ssm_region[i]: + continue # Jump one page_size, but ssm page_size may be bigger when kernel # locks block size to a specific value (physical_per_logical scale). page_stride = ( @@ -1081,6 +1101,10 @@ def _build_mamba_remote( # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case # block lengths vary across layers (e.g. MLA). for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + # Only build Mamba descriptors for SSM/Mamba regions. + # Attention/MLA regions do not contain conv or temporal state. + if i < len(self._is_ssm_region) and not self._is_ssm_region[i]: + continue page_stride = nixl_agent_meta.block_lens[i] * remote_physical_per_logical for off, sz in conv_offsets: for blk in range(num_blocks): @@ -1106,6 +1130,10 @@ def _build_fa_local( num_blocks = self.num_blocks * block_size_ratio result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(base_addresses): + # Only build FA descriptors for attention regions. + # SSM/Mamba regions are served by Mamba descriptors. 
+ if i < len(self._is_ssm_region) and self._is_ssm_region[i]: + continue kv_block_len = ( self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False @@ -1147,6 +1175,10 @@ def _build_fa_remote( num_blocks = nixl_agent_meta.num_blocks result: list[tuple[int, int, int]] = [] for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + # Only build FA descriptors for attention regions. + # SSM/Mamba regions are served by Mamba descriptors. + if i < len(self._is_ssm_region) and self._is_ssm_region[i]: + continue # Read our whole local region size from remote.. local_block_len = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False @@ -1367,9 +1399,12 @@ def add_remote_agent( plan = self.tp_mappings[engine_id] ### (Optional) Register local agent memory regions. MLA is not split. + # For hybrid MLA+GDN models, SSM state is TP-sharded and must be split + # to assemble data from multiple remote P ranks even when MLA + # attention is replicated. if ( tp_ratio < 0 - and not self.use_mla + and (not self.use_mla or self._has_mamba) and tp_ratio not in self.src_xfer_handles_by_tp_ratio ): # Remote tp_size > local tp_size: read from multiple remote ranks. @@ -2087,7 +2122,10 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): # D may have to perform multiple reads from different remote ranks. # MLA opt: when P TP > D TP, only a single read is executed for # the first remote rank (cache is duplicated).. - if self.use_mla and tp_ratio < 0: + # For hybrid MLA+GDN models, SSM state is TP-sharded, so multiple + # remote ranks are still needed for the SSM group even when MLA + # attention only needs one rank. + if self.use_mla and tp_ratio < 0 and not self._has_mamba: assert len(read_specs) == 1 for i, spec in enumerate(read_specs): @@ -2101,7 +2139,11 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): req_id, ) # Get side handles. 
- if tp_ratio < 0 and not self.use_mla: + # For hybrid MLA+GDN with tp_ratio < 0, SSM needs split handles to + # assemble data from multiple remote ranks. MLA attention reads + # use the full region (replicated, single rank) but the split + # handle applies offset 0 + full chunk for FA when fa_num_splits=1. + if tp_ratio < 0 and (not self.use_mla or self._has_mamba): assert remote_block_size == self.block_size # Remote tp_size > local tp_size: we must perform multiple # reads. Get the memory chunk onto which we will write to. From d8a4ad59e6f00478a7290d23a6cab3d7093881e8 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 14:40:17 +0800 Subject: [PATCH 05/36] remove block size assert Signed-off-by: JaredforReal --- vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index d9b0e239e682..c59e8562a568 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2144,7 +2144,6 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): # use the full region (replicated, single rank) but the split # handle applies offset 0 + full chunk for FA when fa_num_splits=1. if tp_ratio < 0 and (not self.use_mla or self._has_mamba): - assert remote_block_size == self.block_size # Remote tp_size > local tp_size: we must perform multiple # reads. Get the memory chunk onto which we will write to. 
local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i] From 1987f572293728d2ddd3adf1de50d421b3b3b05f Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 15:25:18 +0800 Subject: [PATCH 06/36] fix P --- .../kv_transfer/kv_connector/v1/nixl/worker.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index c59e8562a568..1cddfe99f8eb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2150,9 +2150,12 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): else: # Single read from remote, we write to the whole memory region. # Also handle remote block size different from local block size. - local_xfer_side_handle = self.src_xfer_handles_by_block_size[ - remote_block_size - ] + # Use remote block_size handle if registered (block_size_ratio > 1), + # otherwise fall back to local block_size handle. + local_xfer_side_handle = self.src_xfer_handles_by_block_size.get( + remote_block_size, + self.src_xfer_handles_by_block_size[self.block_size], + ) # Destination handle: remote_engine_id -> remote_rank -> handle. 
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][ From 62868b259c6fc67cdb821e75b92f68004b30c24b Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 15:54:00 +0800 Subject: [PATCH 07/36] notify right ranks Signed-off-by: JaredforReal --- .../kv_transfer/kv_connector/v1/nixl/worker.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 1cddfe99f8eb..2e6f46a314a4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2174,11 +2174,19 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): if self.use_mla and tp_ratio < 0 and read_specs: # ..but we still need to notify the other remote ranks that we # have the blocks we need so they can update the request state. + # Only notify ranks that we actually read from (all_source_ranks), + # not all remote ranks — ranks we didn't read from don't track + # this request and would log "unrecognized request" errors. 
notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() remote_agents = self._remote_agents[meta.remote.engine_id] - for rank_to_notify, agent in remote_agents.items(): - if rank_to_notify != read_specs[0].remote_rank: - self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) + for rank_to_notify in plan.all_source_ranks: + if ( + rank_to_notify != read_specs[0].remote_rank + and rank_to_notify in remote_agents + ): + self.nixl_wrapper.send_notif( + remote_agents[rank_to_notify], notif_msg=notif_id + ) def _read_blocks( self, From ebe73852e705e62819ff27a8a45fc86ccac2bab9 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 18:14:50 +0800 Subject: [PATCH 08/36] store Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 2e6f46a314a4..c745b5f5bad4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2172,21 +2172,24 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): ) if self.use_mla and tp_ratio < 0 and read_specs: - # ..but we still need to notify the other remote ranks that we - # have the blocks we need so they can update the request state. - # Only notify ranks that we actually read from (all_source_ranks), - # not all remote ranks — ranks we didn't read from don't track - # this request and would log "unrecognized request" errors. 
- notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() - remote_agents = self._remote_agents[meta.remote.engine_id] - for rank_to_notify in plan.all_source_ranks: - if ( - rank_to_notify != read_specs[0].remote_rank - and rank_to_notify in remote_agents - ): - self.nixl_wrapper.send_notif( - remote_agents[rank_to_notify], notif_msg=notif_id - ) + # MLA attention KV is replicated across P ranks. D reads from + # one P rank, but other P ranks also hold the same KV and need + # to be notified so they can release blocks early (vs. timeout). + # + # For hybrid MLA+GDN models, skip MLA notification entirely. + # MLA attention has only one source rank (replicated), so no + # *other* MLA source needs notifying. SSM source ranks are + # already notified via the NIXL read's built-in notif_msg, and + # sending extra MLA notifications to them causes "unrecognized + # request" errors because the P-side request tracking dicts + # (_reqs_to_send / _reqs_to_process) are only populated on the + # decode side. 
+ if not self._has_mamba: + notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() + remote_agents = self._remote_agents[meta.remote.engine_id] + for rank_to_notify, agent in remote_agents.items(): + if rank_to_notify != read_specs[0].remote_rank: + self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) def _read_blocks( self, From 59d26de8208ef74354451deb0be765254079061d Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 19:47:58 +0800 Subject: [PATCH 09/36] add debug logging Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 314 +++++++++++++++++- 1 file changed, 313 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index c745b5f5bad4..539082d1da4e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -82,6 +82,10 @@ logger = init_logger(__name__) +# Diagnostic flag for descriptor debugging. Set to True to enable verbose +# logging of KV cache registration, descriptor building, and transfer data. +_DESC_DEBUG = True + class NixlConnectorWorker: """Implementation of Worker side methods""" @@ -102,6 +106,24 @@ def _compute_desc_ids( num_blocks = int(num_blocks * block_size_ratio) num_fa_descs = num_fa_regions * num_blocks + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] _compute_desc_ids: num_fa_regions=%s, " + "num_ssm_regions=%s, dst_num_blocks=%s, block_size_ratio=%s, " + "physical_blocks_per_logical=%s => num_blocks=%s, " + "num_fa_descs=%s (FA/Mamba boundary), " + "block_ids_lengths=%s, group_spec_types=%s", + num_fa_regions, + num_ssm_regions, + dst_num_blocks, + block_size_ratio, + physical_blocks_per_logical, + num_blocks, + num_fa_descs, + [len(g) for g in block_ids], + [t.__name__ for t in self._group_spec_types], + ) + # All-attention fast path: single vectorized broadcast. 
if num_ssm_regions == 0: # NOTE (NickLucche) With HMA, every kv group has the same number of layers @@ -147,7 +169,17 @@ def _compute_desc_ids( f"Unknown spec type {self._group_spec_types[i]} at index {i}" ) - return np.concatenate(all_descs) + result = np.concatenate(all_descs) + if _DESC_DEBUG: + sample = result[:10].tolist() if len(result) > 10 else result.tolist() + logger.warning( + "[DESC-DEBUG] _compute_desc_ids result: total=%d, " + "per_group_counts=%s, sample_ids=%s", + len(result), + [len(d) for d in all_descs], + sample, + ) + return result def _build_local_splits_from_plan( self, @@ -947,6 +979,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) assert len(self.block_len_per_layer) == len(seen_base_addresses) + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] register_kv_caches summary: " + "num_unique_regions=%d, is_ssm_region=%s, " + "block_len_per_layer=%s, group_spec_types=%s, " + "has_mamba=%s, use_mla=%s, " + "physical_blocks_per_logical=%s, logical_num_blocks=%s, " + "num_blocks=%s", + len(seen_base_addresses), + self._is_ssm_region, + self.block_len_per_layer, + [t.__name__ for t in self._group_spec_types], + self._has_mamba, + self.use_mla, + self._physical_blocks_per_logical_kv_block, + self._logical_num_blocks, + self.num_blocks, + ) + self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses # FA regions count only attention layers. 
SSM/Mamba regions are # served by Mamba descriptors, not FA descriptors, and must be @@ -957,6 +1008,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): num_attention_base = len(caches_data) - sum(self._is_ssm_region) self.num_regions = num_attention_base + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] Region counts: total_caches_data=%d, " + "num_ssm_regions=%d, num_attention_base=%d, " + "num_regions_before_split=%d", + len(caches_data), + sum(self._is_ssm_region), + num_attention_base, + self.num_regions, + ) + if self.transfer_topo.virtually_split_kv_in_blocks: # NOTE (NickLucche) When FlashInfer is used, memory is registered # with joint KV for each block. This minimizes the overhead in @@ -1049,7 +1111,28 @@ def _build_mamba_local( # building Mamba descriptors for them would reference addresses # outside their registered memory, causing prepXferDlist failures. if not self._is_ssm_region[i]: + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] _build_mamba_local: SKIP region %d " + "(not SSM), base=0x%x", + i, + base_addr, + ) continue + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] _build_mamba_local: BUILD region %d " + "(SSM), base=0x%x, page_stride=%d, conv_offsets=%s, " + "ssm_size=%d, num_blocks=%d", + i, + base_addr, + self.block_len_per_layer[i] + // block_size_ratio + * physical_per_logical, + conv_offsets, + ssm_size, + num_blocks, + ) # Jump one page_size, but ssm page_size may be bigger when kernel # locks block size to a specific value (physical_per_logical scale). page_stride = ( @@ -1104,7 +1187,28 @@ def _build_mamba_remote( # Only build Mamba descriptors for SSM/Mamba regions. # Attention/MLA regions do not contain conv or temporal state. 
if i < len(self._is_ssm_region) and not self._is_ssm_region[i]: + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] _build_mamba_remote: SKIP region %d " + "(not SSM), remote_base=0x%x", + i, + base_addr, + ) continue + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] _build_mamba_remote: BUILD region %d " + "(SSM), remote_base=0x%x, page_stride=%d, " + "conv_offsets=%s, ssm_read_size=%d, " + "local_offset=%d, num_blocks=%d", + i, + base_addr, + nixl_agent_meta.block_lens[i] * remote_physical_per_logical, + conv_offsets, + ssm_read_size, + local_offset, + num_blocks, + ) page_stride = nixl_agent_meta.block_lens[i] * remote_physical_per_logical for off, sz in conv_offsets: for blk in range(num_blocks): @@ -1133,6 +1237,12 @@ def _build_fa_local( # Only build FA descriptors for attention regions. # SSM/Mamba regions are served by Mamba descriptors. if i < len(self._is_ssm_region) and self._is_ssm_region[i]: + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] _build_fa_local: SKIP region %d (SSM), base=0x%x", + i, + base_addr, + ) continue kv_block_len = ( self.get_backend_aware_kv_block_len( @@ -1141,6 +1251,18 @@ def _build_fa_local( // block_size_ratio ) page_stride = self.block_len_per_layer[i] // block_size_ratio + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] _build_fa_local: BUILD region %d " + "(attn), base=0x%x, kv_block_len=%d, page_stride=%d, " + "num_blocks=%d, virtual_split=%s", + i, + base_addr, + kv_block_len, + page_stride, + num_blocks, + self.transfer_topo.virtually_split_kv_in_blocks, + ) for block_id in range(num_blocks): block_offset = block_id * page_stride addr = base_addr + block_offset @@ -1178,7 +1300,30 @@ def _build_fa_remote( # Only build FA descriptors for attention regions. # SSM/Mamba regions are served by Mamba descriptors. 
if i < len(self._is_ssm_region) and self._is_ssm_region[i]: + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] _build_fa_remote: SKIP region %d " + "(SSM), remote_base=0x%x", + i, + base_addr, + ) continue + if _DESC_DEBUG: + local_block_len = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=True, mamba_view=False + ) + logger.warning( + "[DESC-DEBUG] _build_fa_remote: BUILD region %d " + "(attn), remote_base=0x%x, local_block_len=%d, " + "remote_block_len=%d, num_blocks=%d, " + "num_attn_reads=%d", + i, + base_addr, + local_block_len, + nixl_agent_meta.block_lens[i], + num_blocks, + num_attn_reads, + ) # Read our whole local region size from remote.. local_block_len = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False @@ -1233,6 +1378,7 @@ def register_local_xfer_handler( local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] blocks_data = self._build_fa_local(local_base_addresses, block_size_ratio) + fa_desc_count = len(blocks_data) logger.debug( "Created %s blocks for src engine %s and rank %s on device id %s", len(blocks_data), @@ -1253,6 +1399,43 @@ def register_local_xfer_handler( self._build_mamba_local(local_base_addresses, block_size_ratio) ) + if _DESC_DEBUG: + mamba_desc_count = len(blocks_data) - fa_desc_count + logger.warning( + "[DESC-DEBUG] register_local_xfer_handler: " + "fa_desc_count=%d, mamba_desc_count=%d, total=%d, " + "num_descs(boundary)=%d, num_regions=%d, num_blocks=%d, " + "is_ssm_region=%s", + fa_desc_count, + mamba_desc_count, + len(blocks_data), + self.num_descs, + self.num_regions, + self.num_blocks, + self._is_ssm_region, + ) + # Spot-check first/last FA desc and first/last Mamba desc + if blocks_data: + logger.warning( + "[DESC-DEBUG] local FA desc[0]: addr=0x%x size=%d dev=%d", + *blocks_data[0], + ) + if fa_desc_count > 1: + logger.warning( + "[DESC-DEBUG] local FA desc[%d]: addr=0x%x size=%d dev=%d", + fa_desc_count - 1, + *blocks_data[fa_desc_count - 1], + 
) + if mamba_desc_count > 0: + logger.warning( + "[DESC-DEBUG] local Mamba desc[0]: addr=0x%x size=%d dev=%d", + *blocks_data[fa_desc_count], + ) + logger.warning( + "[DESC-DEBUG] local Mamba desc[last]: addr=0x%x size=%d dev=%d", + *blocks_data[-1], + ) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data @@ -1435,6 +1618,7 @@ def add_remote_agent( nixl_agent_meta, block_size_ratio, ) + remote_fa_count = len(blocks_data) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and local rank %s", len(blocks_data), @@ -1456,6 +1640,35 @@ def add_remote_agent( ) ) + if _DESC_DEBUG: + remote_mamba_count = len(blocks_data) - remote_fa_count + logger.warning( + "[DESC-DEBUG] add_remote_agent: remote_fa_descs=%d, " + "remote_mamba_descs=%d, total_remote=%d, " + "tp_ratio=%s, block_size_ratio=%s, " + "remote_num_blocks=%s, remote_base_addrs=%d, " + "remote_block_lens=%s", + remote_fa_count, + remote_mamba_count, + len(blocks_data), + tp_ratio, + block_size_ratio, + nixl_agent_meta.num_blocks, + len(nixl_agent_meta.kv_caches_base_addr), + nixl_agent_meta.block_lens, + ) + # Spot-check remote descriptors + if blocks_data: + logger.warning( + "[DESC-DEBUG] remote FA desc[0]: addr=0x%x size=%d dev=%d", + *blocks_data[0], + ) + if remote_mamba_count > 0: + logger.warning( + "[DESC-DEBUG] remote Mamba desc[0]: addr=0x%x size=%d dev=%d", + *blocks_data[remote_fa_count], + ) + # Register with NIXL. 
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) self.dst_xfer_side_handles[engine_id][remote_tp_rank] = ( @@ -1795,6 +2008,10 @@ def get_finished(self) -> tuple[set[str], set[str]]: meta = self._recving_metadata.pop(req_id, None) assert meta is not None, f"{req_id} not found in recving_metadata list" + # Spot-check KV tensor values after transfer + if _DESC_DEBUG and meta.remote is not None: + self._spot_check_kv_after_recv(req_id, meta) + # Skip KV sync and post-processing for failed requests if req_id in failed_recv_reqs: logger.warning( @@ -2305,6 +2522,37 @@ def _read_blocks( assert len(local_block_descs_ids) == len(remote_block_descs_ids) + if _DESC_DEBUG: + local_sample = local_block_descs_ids[:10].tolist() + remote_sample = remote_block_descs_ids[:10].tolist() + logger.warning( + "[DESC-DEBUG] _read_blocks req=%s: num_desc_pairs=%d, " + "local_block_ids=%s, remote_block_ids=%s, " + "local_desc_sample=%s, remote_desc_sample=%s", + request_id, + len(local_block_descs_ids), + [ids[:3] for ids in local_block_ids], + [ids[:3] for ids in remote_block_ids], + local_sample, + remote_sample, + ) + # Spot-check descriptor content (addr/size) from src/dst blocks + if hasattr(self, "src_blocks_data") and self.src_blocks_data: + for check_idx in [0, -1]: + if check_idx < len(local_block_descs_ids): + did = local_block_descs_ids[check_idx] + if 0 <= did < len(self.src_blocks_data): + addr, sz, dev = self.src_blocks_data[did] + logger.warning( + "[DESC-DEBUG] local_desc[%d]=%d => " + "src_blocks_data: addr=0x%x size=%d dev=%d", + check_idx, + did, + addr, + sz, + dev, + ) + # Prepare transfer with Nixl. 
handle = None try: @@ -2334,6 +2582,70 @@ def _read_blocks( ) self._handle_failed_transfer(request_id, handle) + # Counter for spot-check rate limiting + _spot_check_count: int = 0 + + def _spot_check_kv_after_recv(self, req_id: str, meta: ReqMeta) -> None: + """Spot-check KV tensor values after a transfer completes.""" + NixlConnectorWorker._spot_check_count += 1 + if NixlConnectorWorker._spot_check_count > 5: + return # Only check the first few transfers + + num_kv_caches = len(self.device_kv_caches) + logger.warning( + "[DESC-DEBUG] _spot_check_kv_after_recv req=%s " + "num_kv_caches=%d local_block_ids=%s", + req_id, + num_kv_caches, + [ids[:3] for ids in meta.local_physical_block_ids], + ) + + # Check the actual tensor values for the first few cache entries + checked = 0 + for layer_name, cache_val in list(self.device_kv_caches.items())[:4]: + if isinstance(cache_val, torch.Tensor): + t = cache_val + elif isinstance(cache_val, (tuple, list)): + # Could be (conv_state, recurrent_state) for Mamba or + # (k_cache, v_cache) for attention + t = cache_val[0] if len(cache_val) > 0 else None + if t is None: + continue + else: + continue + + if not isinstance(t, torch.Tensor): + continue + + has_nan = torch.isnan(t).any().item() + has_inf = torch.isinf(t).any().item() + mean_val = t.float().mean().item() + std_val = t.float().std().item() + first_vals = t.flatten()[:4].tolist() + + logger.warning( + "[DESC-DEBUG] layer=%s shape=%s dtype=%s " + "mean=%.6f std=%.6f nan=%s inf=%s first_vals=%s", + layer_name, + tuple(t.shape), + t.dtype, + mean_val, + std_val, + has_nan, + has_inf, + first_vals, + ) + checked += 1 + + if checked == 0: + logger.warning( + "[DESC-DEBUG] No torch.Tensor found in device_kv_caches. 
Types: %s", + { + k: type(v).__name__ + for k, v in list(self.device_kv_caches.items())[:4] + }, + ) + def get_mapped_blocks( self, block_ids: np.ndarray, block_size_ratio: int ) -> np.ndarray: From ea8f8f0ef22fbc728a7c0bf763ef4a767262f877 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 20:08:23 +0800 Subject: [PATCH 10/36] add debug logging Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 47 +++++++++++++++++-- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 539082d1da4e..231636833026 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -891,6 +891,29 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if isinstance(layer_spec, UniformTypeKVCacheSpecs): # MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs layer_spec = layer_spec.kv_cache_specs[layer_name] + if _DESC_DEBUG: + cache_type = type(cache_or_caches).__name__ + if isinstance(cache_or_caches, torch.Tensor): + cache_info = f"Tensor(shape={tuple(cache_or_caches.shape)}, dtype={cache_or_caches.dtype})" + elif isinstance(cache_or_caches, (tuple, list)): + parts = [] + for c in cache_or_caches: + if isinstance(c, torch.Tensor): + parts.append(f"Tensor(shape={tuple(c.shape)}, ptr=0x{c.data_ptr():x})") + else: + parts.append(type(c).__name__) + cache_info = f"({', '.join(parts)})" + else: + cache_info = cache_type + logger.warning( + "[DESC-DEBUG] register layer=%s: spec=%s, is_ssm=%s, " + "cache_type=%s, cache_info=%s", + layer_name, + type(layer_spec).__name__, + isinstance(layer_spec, MambaSpec), + cache_type, + cache_info, + ) cache_list = self.transfer_topo.get_transfer_cache_regions( cache_or_caches, layer_spec ) @@ -929,11 +952,27 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # across 
groups. This results in skipping all tensors but the ones # pointed to by group0. Also, generally we will have more blocks # per tensor but fewer regions. - logger.debug("Skipping %s because it's already seen", layer_name) + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] SKIP layer=%s: base=0x%x already seen " + "(spec=%s, is_ssm=%s)", + layer_name, + base_addr, + type(layer_spec).__name__, + isinstance(layer_spec, MambaSpec), + ) continue - logger.debug( - "Registering layer %s with cache shape: %s", layer_name, cache.shape - ) + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] REGISTER layer=%s: base=0x%x shape=%s " + "spec=%s is_ssm=%s page_size=%d", + layer_name, + base_addr, + tuple(cache.shape), + type(layer_spec).__name__, + isinstance(layer_spec, MambaSpec), + physical_page_size, + ) seen_base_addresses.append(base_addr) is_ssm = isinstance(layer_spec, MambaSpec) self._is_ssm_region.append(is_ssm) From c73c26754c27a868c4bfaac3a1973c28f0660459 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 20:39:35 +0800 Subject: [PATCH 11/36] fix mla to be mismatched to ssm spec Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 44 +++++++++++++------ 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 231636833026..dbd413795c04 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -397,10 +397,12 @@ def __init__( # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) self.num_regions = 0 - # Per-region flag: True if the region is an SSM/Mamba layer. - # Populated during register_kv_caches; used to route FA descriptors - # to attention regions and Mamba descriptors to SSM regions. 
+ # Per-region flags: True if the region has ANY SSM/Mamba layers + # or ANY attention/MLA layers, respectively. HMA may back both + # layer types with the same physical tensor, making a region + # dual-purpose (both flags True). self._is_ssm_region: list[bool] = [] + self._is_attn_region: list[bool] = [] # nixl_prepped_dlist_handle. self.src_xfer_handles_by_block_size: dict[int, int] = {} @@ -881,6 +883,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`. self.block_len_per_layer = list[int]() self._is_ssm_region = list[bool]() + self._is_attn_region = list[bool]() for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. @@ -899,7 +902,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): parts = [] for c in cache_or_caches: if isinstance(c, torch.Tensor): - parts.append(f"Tensor(shape={tuple(c.shape)}, ptr=0x{c.data_ptr():x})") + parts.append( + f"Tensor(shape={tuple(c.shape)}, ptr=0x{c.data_ptr():x})" + ) else: parts.append(type(c).__name__) cache_info = f"({', '.join(parts)})" @@ -947,19 +952,26 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # registering a single tensor for both K/V and splitting logically like FI. for cache in cache_list: base_addr = cache.data_ptr() + is_ssm = isinstance(layer_spec, MambaSpec) + is_attn = not is_ssm if base_addr in seen_base_addresses: # NOTE (NickLucche) HMA employs memory pooling to share tensors # across groups. This results in skipping all tensors but the ones - # pointed to by group0. Also, generally we will have more blocks - # per tensor but fewer regions. + # pointed to by group0. However, the same physical tensor may + # back both SSM and attention layers. Accumulate both flags so + # that the region can be dual-purpose. 
+ idx = seen_base_addresses.index(base_addr) + self._is_ssm_region[idx] = self._is_ssm_region[idx] or is_ssm + self._is_attn_region[idx] = self._is_attn_region[idx] or is_attn if _DESC_DEBUG: logger.warning( "[DESC-DEBUG] SKIP layer=%s: base=0x%x already seen " - "(spec=%s, is_ssm=%s)", + "(spec=%s, is_ssm=%s, is_attn=%s)", layer_name, base_addr, type(layer_spec).__name__, - isinstance(layer_spec, MambaSpec), + is_ssm, + is_attn, ) continue if _DESC_DEBUG: @@ -975,7 +987,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) seen_base_addresses.append(base_addr) is_ssm = isinstance(layer_spec, MambaSpec) + is_attn = not is_ssm self._is_ssm_region.append(is_ssm) + self._is_attn_region.append(is_attn) # Only record non-Mamba page sizes. if is_ssm: self.block_len_per_layer.append( @@ -1044,7 +1058,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (For hybrid MLA+GDN models, building FA descriptors for the GDN # region with virtual K/V split would produce out-of-bounds addresses # under heterogeneous TP, failing NIXL prepXferDlist.) - num_attention_base = len(caches_data) - sum(self._is_ssm_region) + # HMA dual-purpose regions (both SSM and attn) are counted here. + num_attention_base = sum(self._is_attn_region) self.num_regions = num_attention_base if _DESC_DEBUG: @@ -1149,6 +1164,7 @@ def _build_mamba_local( # Attention/MLA regions do not contain conv or temporal state; # building Mamba descriptors for them would reference addresses # outside their registered memory, causing prepXferDlist failures. + # Dual-purpose regions (HMA) get both FA and Mamba descs. if not self._is_ssm_region[i]: if _DESC_DEBUG: logger.warning( @@ -1275,10 +1291,11 @@ def _build_fa_local( for i, base_addr in enumerate(base_addresses): # Only build FA descriptors for attention regions. # SSM/Mamba regions are served by Mamba descriptors. 
- if i < len(self._is_ssm_region) and self._is_ssm_region[i]: + # Dual-purpose regions (HMA) get both FA and Mamba descs. + if i < len(self._is_attn_region) and not self._is_attn_region[i]: if _DESC_DEBUG: logger.warning( - "[DESC-DEBUG] _build_fa_local: SKIP region %d (SSM), base=0x%x", + "[DESC-DEBUG] _build_fa_local: SKIP region %d (not attn), base=0x%x", i, base_addr, ) @@ -1338,11 +1355,12 @@ def _build_fa_remote( for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): # Only build FA descriptors for attention regions. # SSM/Mamba regions are served by Mamba descriptors. - if i < len(self._is_ssm_region) and self._is_ssm_region[i]: + # Dual-purpose regions (HMA) get both FA and Mamba descs. + if i < len(self._is_attn_region) and not self._is_attn_region[i]: if _DESC_DEBUG: logger.warning( "[DESC-DEBUG] _build_fa_remote: SKIP region %d " - "(SSM), remote_base=0x%x", + "(not attn), remote_base=0x%x", i, base_addr, ) From 2120f7d098cab9e6d5bc4fd00c756567548e2700 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 20:45:45 +0800 Subject: [PATCH 12/36] close debug logging Signed-off-by: JaredforReal --- vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index dbd413795c04..0d57d25cdaca 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -84,7 +84,7 @@ # Diagnostic flag for descriptor debugging. Set to True to enable verbose # logging of KV cache registration, descriptor building, and transfer data. 
-_DESC_DEBUG = True +_DESC_DEBUG = False class NixlConnectorWorker: From 0c0f050d2df60f13eac71f647f29c86436d7d65b Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 21:30:02 +0800 Subject: [PATCH 13/36] more debug logging Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 47 ++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 0d57d25cdaca..74fe7cb43f92 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -84,7 +84,7 @@ # Diagnostic flag for descriptor debugging. Set to True to enable verbose # logging of KV cache registration, descriptor building, and transfer data. -_DESC_DEBUG = False +_DESC_DEBUG = True class NixlConnectorWorker: @@ -2694,7 +2694,50 @@ def _spot_check_kv_after_recv(self, req_id: str, meta: ReqMeta) -> None: ) checked += 1 - if checked == 0: + # Explicitly check MLA attention layers (key contains '.attn') + # These are critical — if their data is zero/garbage after transfer, + # MLA attention will produce garbage output. 
+ block_ids = meta.local_physical_block_ids + mla_layers = [ + (k, v) for k, v in self.device_kv_caches.items() + if ".attn" in k and isinstance(v, torch.Tensor) + ] + for layer_name, t in mla_layers[:4]: + has_nan = torch.isnan(t).any().item() + has_inf = torch.isinf(t).any().item() + mean_val = t.float().mean().item() + std_val = t.float().std().item() + first_vals = t.flatten()[:4].tolist() + # Check specific blocks that were transferred + block_stats = "" + if block_ids: + for _, bids in enumerate( + block_ids[:1] + ): + for bid in bids[:2]: + if bid < t.shape[0]: + blk = t[bid].float() + block_stats += ( + f" block[{bid}]:mean={blk.mean():.4f}" + f",std={blk.std():.4f}" + f",first4={blk.flatten()[:4].tolist()}" + ) + logger.warning( + "[DESC-DEBUG] MLA layer=%s shape=%s dtype=%s " + "global_mean=%.6f global_std=%.6f nan=%s inf=%s " + "first_vals=%s%s", + layer_name, + tuple(t.shape), + t.dtype, + mean_val, + std_val, + has_nan, + has_inf, + first_vals, + block_stats, + ) + + if checked == 0 and not mla_layers: logger.warning( "[DESC-DEBUG] No torch.Tensor found in device_kv_caches. 
Types: %s", { From fe2016f4419670457965ed30b9a9c94f38fd7542 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 21:41:48 +0800 Subject: [PATCH 14/36] fix debug logging Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 88 ++++++------------- 1 file changed, 28 insertions(+), 60 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 74fe7cb43f92..fba89b0e86eb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2657,14 +2657,22 @@ def _spot_check_kv_after_recv(self, req_id: str, meta: ReqMeta) -> None: [ids[:3] for ids in meta.local_physical_block_ids], ) - # Check the actual tensor values for the first few cache entries + # Memory-efficient spot check: only examine transferred blocks, + # not the full tensor (which can be >100 GB and cause OOM). + block_ids = meta.local_physical_block_ids + + # Collect unique block IDs to check (across all groups) + bids_to_check: list[int] = [] + for bids in block_ids[:4]: + for bid in bids[:2]: + if bid not in bids_to_check: + bids_to_check.append(bid) + checked = 0 - for layer_name, cache_val in list(self.device_kv_caches.items())[:4]: + for layer_name, cache_val in list(self.device_kv_caches.items())[:8]: if isinstance(cache_val, torch.Tensor): t = cache_val elif isinstance(cache_val, (tuple, list)): - # Could be (conv_state, recurrent_state) for Mamba or - # (k_cache, v_cache) for attention t = cache_val[0] if len(cache_val) > 0 else None if t is None: continue @@ -2674,70 +2682,30 @@ def _spot_check_kv_after_recv(self, req_id: str, meta: ReqMeta) -> None: if not isinstance(t, torch.Tensor): continue - has_nan = torch.isnan(t).any().item() - has_inf = torch.isinf(t).any().item() - mean_val = t.float().mean().item() - std_val = t.float().std().item() - first_vals = t.flatten()[:4].tolist() - - logger.warning( -
"[DESC-DEBUG] layer=%s shape=%s dtype=%s " - "mean=%.6f std=%.6f nan=%s inf=%s first_vals=%s", - layer_name, - tuple(t.shape), - t.dtype, - mean_val, - std_val, - has_nan, - has_inf, - first_vals, - ) - checked += 1 - - # Explicitly check MLA attention layers (key contains '.attn') - # These are critical — if their data is zero/garbage after transfer, - # MLA attention will produce garbage output. - block_ids = meta.local_physical_block_ids - mla_layers = [ - (k, v) for k, v in self.device_kv_caches.items() - if ".attn" in k and isinstance(v, torch.Tensor) - ] - for layer_name, t in mla_layers[:4]: - has_nan = torch.isnan(t).any().item() - has_inf = torch.isinf(t).any().item() - mean_val = t.float().mean().item() - std_val = t.float().std().item() - first_vals = t.flatten()[:4].tolist() - # Check specific blocks that were transferred + is_mla = ".attn" in layer_name + # Check only transferred blocks — one block is tiny (~18 KB) block_stats = "" - if block_ids: - for _, bids in enumerate( - block_ids[:1] - ): - for bid in bids[:2]: - if bid < t.shape[0]: - blk = t[bid].float() - block_stats += ( - f" block[{bid}]:mean={blk.mean():.4f}" - f",std={blk.std():.4f}" - f",first4={blk.flatten()[:4].tolist()}" - ) + for bid in bids_to_check[:3]: + if bid < t.shape[0]: + blk = t[bid].float() + block_stats += ( + f" blk[{bid}]:mean={blk.mean():.6f}" + f",std={blk.std():.6f}" + f",nan={torch.isnan(blk).any().item()}" + f",inf={torch.isinf(blk).any().item()}" + f",f4={blk.flatten()[:4].tolist()}" + ) logger.warning( - "[DESC-DEBUG] MLA layer=%s shape=%s dtype=%s " - "global_mean=%.6f global_std=%.6f nan=%s inf=%s " - "first_vals=%s%s", + "[DESC-DEBUG] %s layer=%s shape=%s dtype=%s%s", + "MLA" if is_mla else "SSM", layer_name, tuple(t.shape), t.dtype, - mean_val, - std_val, - has_nan, - has_inf, - first_vals, block_stats, ) + checked += 1 - if checked == 0 and not mla_layers: + if checked == 0: logger.warning( "[DESC-DEBUG] No torch.Tensor found in device_kv_caches. 
Types: %s", { From 1f4173f85f10077f80aa06c4ae76e8566e6df413 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 22:03:22 +0800 Subject: [PATCH 15/36] fix debug logging Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index fba89b0e86eb..2b4f14385723 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2657,6 +2657,25 @@ def _spot_check_kv_after_recv(self, req_id: str, meta: ReqMeta) -> None: [ids[:3] for ids in meta.local_physical_block_ids], ) + # Dump ALL keys with types and shapes so we can find MLA entries + all_keys_info = [] + for k, v in self.device_kv_caches.items(): + if isinstance(v, torch.Tensor): + all_keys_info.append(f"{k}=Tensor{tuple(v.shape)}") + elif isinstance(v, (tuple, list)): + parts = [] + for c in v: + if isinstance(c, torch.Tensor): + parts.append(f"T{tuple(c.shape)}") + else: + parts.append(type(c).__name__) + all_keys_info.append(f"{k}=({','.join(parts)})") + else: + all_keys_info.append(f"{k}={type(v).__name__}") + logger.warning( + "[DESC-DEBUG] device_kv_caches keys: %s", all_keys_info + ) + # Memory-efficient spot check: only examine transferred blocks, # not the full tensor (which can be >100 GB and cause OOM).
block_ids = meta.local_physical_block_ids From 8636426ffa5f4a419cbcf4ad9b42bb5f534b38fd Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 8 Jun 2026 22:14:39 +0800 Subject: [PATCH 16/36] more debug logging Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 55 +++++++++++-------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 2b4f14385723..a3bf4ce78d24 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2687,22 +2687,8 @@ def _spot_check_kv_after_recv(self, req_id: str, meta: ReqMeta) -> None: if bid not in bids_to_check: bids_to_check.append(bid) - checked = 0 - for layer_name, cache_val in list(self.device_kv_caches.items())[:8]: - if isinstance(cache_val, torch.Tensor): - t = cache_val - elif isinstance(cache_val, (tuple, list)): - t = cache_val[0] if len(cache_val) > 0 else None - if t is None: - continue - else: - continue - - if not isinstance(t, torch.Tensor): - continue - - is_mla = ".attn" in layer_name - # Check only transferred blocks — one block is tiny (~18 KB) + # Helper to spot-check one layer's transferred blocks + def _check_layer(label: str, name: str, t: torch.Tensor) -> None: block_stats = "" for bid in bids_to_check[:3]: if bid < t.shape[0]: @@ -2716,15 +2702,38 @@ def _spot_check_kv_after_recv(self, req_id: str, meta: ReqMeta) -> None: ) logger.warning( "[DESC-DEBUG] %s layer=%s shape=%s dtype=%s%s", - "MLA" if is_mla else "SSM", - layer_name, - tuple(t.shape), - t.dtype, - block_stats, + label, name, tuple(t.shape), t.dtype, block_stats, ) - checked += 1 - if checked == 0: + # Check first 2 KDA layers (conv state from tuple) + checked = 0 + for layer_name, cache_val in list(self.device_kv_caches.items()): + if checked >= 2: + break + if isinstance(cache_val, (tuple, list)): + t = 
cache_val[0] if len(cache_val) > 0 else None + if t is None or not isinstance(t, torch.Tensor): + continue + if ".attn" in layer_name: + continue # skip MLA entries here + _check_layer("SSM", layer_name, t) + checked += 1 + + # CRITICAL: Check ALL MLA layers (keys containing ".self_attn.attn") + mla_count = 0 + for layer_name, cache_val in self.device_kv_caches.items(): + if ".attn" not in layer_name: + continue + if isinstance(cache_val, torch.Tensor): + _check_layer("MLA", layer_name, cache_val) + mla_count += 1 + elif isinstance(cache_val, (tuple, list)): + for i, c in enumerate(cache_val): + if isinstance(c, torch.Tensor): + _check_layer(f"MLA[{i}]", layer_name, c) + mla_count += 1 + + if checked == 0 and mla_count == 0: logger.warning( "[DESC-DEBUG] No torch.Tensor found in device_kv_caches. Types: %s", { From 7de4e5a6c85cdceed64a6014f485d244a17fd962 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 00:34:49 +0800 Subject: [PATCH 17/36] fix with right stripe Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 164 +++++++++++++++--- 1 file changed, 142 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index a3bf4ce78d24..01f4380ad344 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -403,6 +403,11 @@ def __init__( # dual-purpose (both flags True). self._is_ssm_region: list[bool] = [] self._is_attn_region: list[bool] = [] + # For HMA dual-purpose regions: maps region_idx to the attention + # spec's physical page size (bytes). When KDA (MambaSpec) and MLA + # (AttentionSpec) share the same backing tensor, block_len_per_layer + # stores KDA's stride. FA descriptors need MLA's stride instead. + self._attn_block_len: dict[int, int] = {} # nixl_prepped_dlist_handle. 
self.src_xfer_handles_by_block_size: dict[int, int] = {} @@ -884,6 +889,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_len_per_layer = list[int]() self._is_ssm_region = list[bool]() self._is_attn_region = list[bool]() + self._attn_block_len = dict[int, int]() for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. @@ -963,6 +969,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): idx = seen_base_addresses.index(base_addr) self._is_ssm_region[idx] = self._is_ssm_region[idx] or is_ssm self._is_attn_region[idx] = self._is_attn_region[idx] or is_attn + # Record the attention spec's stride so that FA descriptors + # use MLA's page size instead of KDA's for dual-purpose + # regions. + if is_attn: + self._attn_block_len[idx] = physical_page_size + if _DESC_DEBUG: + ssm_stride = self.block_len_per_layer[idx] + logger.warning( + "[DESC-DEBUG] DUAL-PURPOSE region %d: " + "ssm_block_len=%d, attn_block_len=%d, " + "stride_mismatch=%s", + idx, + ssm_stride, + physical_page_size, + ssm_stride != physical_page_size, + ) if _DESC_DEBUG: logger.warning( "[DESC-DEBUG] SKIP layer=%s: base=0x%x already seen " @@ -1033,16 +1055,26 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert len(self.block_len_per_layer) == len(seen_base_addresses) if _DESC_DEBUG: + dual_purpose = [ + i + for i in range(len(self._is_ssm_region)) + if i < len(self._is_attn_region) + and self._is_ssm_region[i] + and self._is_attn_region[i] + ] logger.warning( "[DESC-DEBUG] register_kv_caches summary: " "num_unique_regions=%d, is_ssm_region=%s, " - "block_len_per_layer=%s, group_spec_types=%s, " + "block_len_per_layer=%s, attn_block_len=%s, " + "dual_purpose_regions=%s, group_spec_types=%s, " "has_mamba=%s, use_mla=%s, " "physical_blocks_per_logical=%s, logical_num_blocks=%s, " 
"num_blocks=%s", len(seen_base_addresses), self._is_ssm_region, self.block_len_per_layer, + self._attn_block_len, + dual_purpose, [t.__name__ for t in self._group_spec_types], self._has_mamba, self.use_mla, @@ -1051,6 +1083,36 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.num_blocks, ) + # Expand NIXL registration size for dual-purpose regions. KDA + # registers first with its stride (542720 bytes/block), but the + # physical allocation uses max(KDA, MLA) = 552960 bytes/block. + # Without this, the registered region is too small for FA descriptors + # that address MLA data at the larger stride. + for idx in range(len(caches_data)): + if ( + idx < len(self._is_ssm_region) + and self._is_ssm_region[idx] + and idx < len(self._is_attn_region) + and self._is_attn_region[idx] + ): + max_page = max( + self.block_len_per_layer[idx], + self._attn_block_len.get(idx, 0), + ) + old_size = caches_data[idx][1] + new_size = self.num_blocks * max_page + if old_size != new_size: + base, _, device, label = caches_data[idx] + caches_data[idx] = (base, new_size, device, label) + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] DUAL-PURPOSE registration size: " + "region %d, original=%d, expanded=%d", + idx, + old_size, + new_size, + ) + self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses # FA regions count only attention layers. SSM/Mamba regions are # served by Mamba descriptors, not FA descriptors, and must be @@ -1300,13 +1362,33 @@ def _build_fa_local( base_addr, ) continue - kv_block_len = ( - self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=False + # For dual-purpose regions (HMA), block_len_per_layer stores the + # KDA/SSM stride. FA descriptors must use the attention spec's + # stride instead so they address MLA data correctly. 
+ if i in self._attn_block_len: + attn_stride = self._attn_block_len[i] + if self.transfer_topo.virtually_split_kv_in_blocks: + kv_block_len = attn_stride // 2 + else: + kv_block_len = attn_stride + kv_block_len = kv_block_len // block_size_ratio + page_stride = attn_stride // block_size_ratio + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] _build_fa_local: DUAL-PURPOSE region %d, " + "using attn_stride=%d (ssm_stride=%d)", + i, + page_stride, + self.block_len_per_layer[i] // block_size_ratio, + ) + else: + kv_block_len = ( + self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=True, mamba_view=False + ) + // block_size_ratio ) - // block_size_ratio - ) - page_stride = self.block_len_per_layer[i] // block_size_ratio + page_stride = self.block_len_per_layer[i] // block_size_ratio if _DESC_DEBUG: logger.warning( "[DESC-DEBUG] _build_fa_local: BUILD region %d " @@ -1328,9 +1410,12 @@ def _build_fa_local( # Separate and interleave K/V regions to maintain the same # descs ordering. This is needed for selecting contiguous heads # when split across TP ranks. - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False - ) + if i in self._attn_block_len: + second_split = self._attn_block_len[i] // 2 + else: + second_split = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=False, mamba_view=False + ) for block_id in range(num_blocks): block_offset = block_id * page_stride addr = base_addr + block_offset @@ -1382,9 +1467,26 @@ def _build_fa_remote( num_attn_reads, ) # Read our whole local region size from remote.. - local_block_len = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=False - ) + # For dual-purpose regions, use the attention spec's stride instead + # of block_len_per_layer (which stores KDA's stride). 
+ if i in self._attn_block_len: + attn_stride = self._attn_block_len[i] + if self.transfer_topo.virtually_split_kv_in_blocks: + local_block_len = attn_stride // 2 + else: + local_block_len = attn_stride + if _DESC_DEBUG: + logger.warning( + "[DESC-DEBUG] _build_fa_remote: DUAL-PURPOSE region %d, " + "using attn_stride=%d (remote_block_lens=%d)", + i, + attn_stride, + nixl_agent_meta.block_lens[i], + ) + else: + local_block_len = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=True, mamba_view=False + ) remote_kv_block_len = local_block_len // block_size_ratio if block_size_ratio > 1: # ..using remote kv_block_len as transfer unit @@ -1393,7 +1495,15 @@ def _build_fa_remote( local_block_len = local_block_len // num_attn_reads rank_offset = plan.rank_offset_factor * remote_kv_block_len - page_size = nixl_agent_meta.block_lens[i] + # For dual-purpose regions, use the attention stride as page size + # so descriptors step through remote memory at MLA's stride. + # NOTE: This assumes homogeneous TP where both sides have the same + # attention spec stride. For heterogeneous TP, the remote's attn + # stride must be propagated via metadata. + if i in self._attn_block_len: + page_size = self._attn_block_len[i] + else: + page_size = nixl_agent_meta.block_lens[i] for block_id in range(num_blocks): block_offset = block_id * page_size # For each block, grab the kv heads chunk belonging to current local @@ -1403,15 +1513,23 @@ def _build_fa_remote( if self.transfer_topo.virtually_split_kv_in_blocks: # With FlashInfer index V separately to allow head splitting. 
- second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False - ) + if i in self._attn_block_len: + second_split = self._attn_block_len[i] // 2 + else: + second_split = self.get_backend_aware_kv_block_len( + layer_idx=i, first_split=False, mamba_view=False + ) second_split = second_split // num_attn_reads + v_stride = ( + self._attn_block_len[i] + if i in self._attn_block_len + else nixl_agent_meta.block_lens[i] + ) for block_id in range(num_blocks): block_offset = block_id * page_size addr = base_addr + block_offset + rank_offset # Hop over the first split of remote page, K, to read V. - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + v_addr = addr + v_stride // 2 result.append((v_addr, second_split, nixl_agent_meta.device_id)) return result @@ -2672,9 +2790,7 @@ def _spot_check_kv_after_recv(self, req_id: str, meta: ReqMeta) -> None: all_keys_info.append(f"{k}=({','.join(parts)})") else: all_keys_info.append(f"{k}={type(v).__name__}") - logger.warning( - "[DESC-DEBUG] device_kv_caches keys: %s", all_keys_info - ) + logger.warning("[DESC-DEBUG] device_kv_caches keys: %s", all_keys_info) # Memory-efficient spot check: only examine transferred blocks, # not the full tensor (which can be >100 GB and cause OOM). 
@@ -2702,7 +2818,11 @@ def _check_layer(label: str, name: str, t: torch.Tensor) -> None: ) logger.warning( "[DESC-DEBUG] %s layer=%s shape=%s dtype=%s%s", - label, name, tuple(t.shape), t.dtype, block_stats, + label, + name, + tuple(t.shape), + t.dtype, + block_stats, ) # Check first 2 KDA layers (conv state from tuple) From 69d0bc62d8610e8986078811b22ca4de9b3de00f Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 00:56:00 +0800 Subject: [PATCH 18/36] more Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/README.md | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/nixl/README.md diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/README.md b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/README.md new file mode 100644 index 000000000000..eda6293efdf6 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/README.md @@ -0,0 +1,60 @@ +# NIXL Connector + +Source location: `vllm/distributed/kv_transfer/kv_connector/v1/nixl`. + +This directory implements `NixlConnector`, the main high-performance +Prefill/Decode disaggregation connector. It handles the v1 scheduler/worker +split, NIXL agent metadata, heterogeneous TP mapping, HMA/hybrid SSM support, +transfer failure reporting, lease/heartbeat behavior, and connector metrics. + +## Key source files + +- `connector.py`: facade class `NixlConnector`; delegates scheduler role to + `NixlConnectorScheduler` and worker role to `NixlConnectorWorker`. +- `scheduler.py`: request lifecycle on the scheduler side: new requests, remote + hit accounting, allocation updates, connector metadata, finished requests, + lease and timeout state. +- `worker.py`: NIXL worker implementation: memory registration, handshake, + remote-agent tracking, transfer reads/writes, HMA handling, prefix-cache + post-processing, heartbeats, and completion reporting. 
+- `metadata.py`: wire and in-process metadata structures, including + `NixlAgentMetadata`, `NixlHandshakePayload`, `RemoteMeta`, `ReqMeta`, and + `NixlConnectorMetadata`. +- `tp_mapping.py`: local-to-remote TP mapping for heterogeneous deployments. + This is the first file to read for `P_TP != D_TP` work. +- `stats.py`: NIXL transfer stats and Prometheus metrics. +- `utils.py`: small utilities local to the NIXL package. + +## NIXL contribution reading path + +1. `connector.py`: understand which calls route to scheduler vs worker. +2. `scheduler.py`: follow `get_num_new_matched_tokens()`, + `update_state_after_alloc()`, `build_connector_meta()`, and + `request_finished()`. +3. `metadata.py`: map scheduler metadata to worker transfer inputs. +4. `worker.py`: read `add_remote_agent()`, registration helpers, load paths, + prefix-cache post-processing, and completion accounting. +5. `tp_mapping.py`: read before touching heterogeneous TP, replicated KV + heads, MLA, or hybrid SSM/GDN/KDA. +6. `stats.py`: extend when adding transfer modes, new failure classes, or + latency counters. + +## Common change areas + +- Heterogeneous TP: `tp_mapping.py`, `worker.py`, and shared topology helpers in + `kv_connector/utils.py`. +- Hybrid SSM/GDN/KDA: `worker.py` plus + `v1/ssm_conv_transfer_utils.py`. +- Prefix caching in P/D mode: `scheduler.py`, `worker.py`, and HMA coordinator + code under `vllm/v1/core/`. +- Connector observability: `stats.py`, `metrics.py`, and failure logging in + `worker.py`. + +## Useful tests and docs + +- Unit tests: `tests/v1/kv_connector/unit/test_nixl_connector.py`. +- HMA/hybrid tests: `tests/v1/kv_connector/unit/test_nixl_connector_hma.py`. +- TP mapping tests: `tests/v1/kv_connector/unit/test_tp_mapping.py`. +- Integration examples: `tests/v1/kv_connector/nixl_integration/`. +- User docs: `docs/features/nixl_connector_usage.md` and + `docs/features/nixl_connector_compatibility.md`. 
From 9b0c349181ca70c2c379a0e148ec0175040052f8 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 01:13:14 +0800 Subject: [PATCH 19/36] more Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 01f4380ad344..d7d32c5008bc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2803,10 +2803,29 @@ def _spot_check_kv_after_recv(self, req_id: str, meta: ReqMeta) -> None: if bid not in bids_to_check: bids_to_check.append(bid) + # Per-group block IDs: for correct spot-checking, each layer type + # should only check its own group's blocks. + # Groups: MambaSpec×N + MLAAttentionSpec×1 + mamba_group_bids: list[int] = [] + attn_group_bids: list[int] = [] + for gi, bids in enumerate(block_ids): + if gi < len(self._group_spec_types) and _is_attention_spec( + self._group_spec_types[gi] + ): + attn_group_bids.extend(bids[:2]) + else: + mamba_group_bids.extend(bids[:2]) + # Helper to spot-check one layer's transferred blocks - def _check_layer(label: str, name: str, t: torch.Tensor) -> None: + def _check_layer( + label: str, + name: str, + t: torch.Tensor, + group_bids: list[int] | None = None, + ) -> None: + check_bids = group_bids if group_bids else bids_to_check[:4] block_stats = "" - for bid in bids_to_check[:3]: + for bid in check_bids[:4]: if bid < t.shape[0]: blk = t[bid].float() block_stats += ( @@ -2836,21 +2855,23 @@ def _check_layer(label: str, name: str, t: torch.Tensor) -> None: continue if ".attn" in layer_name: continue # skip MLA entries here - _check_layer("SSM", layer_name, t) + _check_layer("SSM", layer_name, t, mamba_group_bids) checked += 1 # CRITICAL: Check ALL MLA layers (keys containing ".self_attn.attn") + # Use attn_group_bids so 
we check the MLA group's actual blocks, + # not the KDA group's blocks. mla_count = 0 for layer_name, cache_val in self.device_kv_caches.items(): if ".attn" not in layer_name: continue if isinstance(cache_val, torch.Tensor): - _check_layer("MLA", layer_name, cache_val) + _check_layer("MLA", layer_name, cache_val, attn_group_bids) mla_count += 1 elif isinstance(cache_val, (tuple, list)): for i, c in enumerate(cache_val): if isinstance(c, torch.Tensor): - _check_layer(f"MLA[{i}]", layer_name, c) + _check_layer(f"MLA[{i}]", layer_name, c, attn_group_bids) mla_count += 1 if checked == 0 and mla_count == 0: From 01ddbad2e07d88b2edb70f957b8f31b9aed9e336 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 10:11:01 +0800 Subject: [PATCH 20/36] check ssm dtype Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index d7d32c5008bc..bbc692c680b3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -2844,7 +2844,7 @@ def _check_layer( block_stats, ) - # Check first 2 KDA layers (conv state from tuple) + # Check first 2 KDA layers (conv_state bf16 + ssm_state fp32 from tuple) checked = 0 for layer_name, cache_val in list(self.device_kv_caches.items()): if checked >= 2: @@ -2855,7 +2855,34 @@ def _check_layer( continue if ".attn" in layer_name: continue # skip MLA entries here - _check_layer("SSM", layer_name, t, mamba_group_bids) + # Check conv_state (bf16) — index 0 + _check_layer("SSM-conv", layer_name, t, mamba_group_bids) + # Check ssm_state (fp32) — index 1 + if len(cache_val) > 1: + ssm_t = cache_val[1] + if isinstance(ssm_t, torch.Tensor): + _check_layer( + "SSM-ssm(fp32)", layer_name, ssm_t, mamba_group_bids + ) + # Verify ssm_state view alignment: 
data_ptr should be + # offset past conv_state within the shared tensor. + conv_ptr = t.data_ptr() + ssm_ptr = ssm_t.data_ptr() + ptr_diff = ssm_ptr - conv_ptr + if self._mamba_ssm_size is not None: + expected_conv_bytes = self._mamba_ssm_size[0] + logger.warning( + "[DESC-DEBUG] SSM alignment: " + "conv_ptr=0x%x, ssm_ptr=0x%x, " + "ptr_diff=%d, expected_conv=%d, " + "match=%s, ssm_stride=%s", + conv_ptr, + ssm_ptr, + ptr_diff, + expected_conv_bytes, + ptr_diff == expected_conv_bytes, + ssm_t.stride(), + ) checked += 1 # CRITICAL: Check ALL MLA layers (keys containing ".self_attn.attn") From 6822236d3b8bb4b9064e5c99e3e0df361a494413 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 12:00:19 +0800 Subject: [PATCH 21/36] clean up debug logging in PD disagg Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/README.md | 60 -- .../kv_connector/v1/nixl/worker.py | 511 +----------------- 2 files changed, 6 insertions(+), 565 deletions(-) delete mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/nixl/README.md diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/README.md b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/README.md deleted file mode 100644 index eda6293efdf6..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/README.md +++ /dev/null @@ -1,60 +0,0 @@ -# NIXL Connector - -Source location: `vllm/distributed/kv_transfer/kv_connector/v1/nixl`. - -This directory implements `NixlConnector`, the main high-performance -Prefill/Decode disaggregation connector. It handles the v1 scheduler/worker -split, NIXL agent metadata, heterogeneous TP mapping, HMA/hybrid SSM support, -transfer failure reporting, lease/heartbeat behavior, and connector metrics. - -## Key source files - -- `connector.py`: facade class `NixlConnector`; delegates scheduler role to - `NixlConnectorScheduler` and worker role to `NixlConnectorWorker`. 
-- `scheduler.py`: request lifecycle on the scheduler side: new requests, remote - hit accounting, allocation updates, connector metadata, finished requests, - lease and timeout state. -- `worker.py`: NIXL worker implementation: memory registration, handshake, - remote-agent tracking, transfer reads/writes, HMA handling, prefix-cache - post-processing, heartbeats, and completion reporting. -- `metadata.py`: wire and in-process metadata structures, including - `NixlAgentMetadata`, `NixlHandshakePayload`, `RemoteMeta`, `ReqMeta`, and - `NixlConnectorMetadata`. -- `tp_mapping.py`: local-to-remote TP mapping for heterogeneous deployments. - This is the first file to read for `P_TP != D_TP` work. -- `stats.py`: NIXL transfer stats and Prometheus metrics. -- `utils.py`: small utilities local to the NIXL package. - -## NIXL contribution reading path - -1. `connector.py`: understand which calls route to scheduler vs worker. -2. `scheduler.py`: follow `get_num_new_matched_tokens()`, - `update_state_after_alloc()`, `build_connector_meta()`, and - `request_finished()`. -3. `metadata.py`: map scheduler metadata to worker transfer inputs. -4. `worker.py`: read `add_remote_agent()`, registration helpers, load paths, - prefix-cache post-processing, and completion accounting. -5. `tp_mapping.py`: read before touching heterogeneous TP, replicated KV - heads, MLA, or hybrid SSM/GDN/KDA. -6. `stats.py`: extend when adding transfer modes, new failure classes, or - latency counters. - -## Common change areas - -- Heterogeneous TP: `tp_mapping.py`, `worker.py`, and shared topology helpers in - `kv_connector/utils.py`. -- Hybrid SSM/GDN/KDA: `worker.py` plus - `v1/ssm_conv_transfer_utils.py`. -- Prefix caching in P/D mode: `scheduler.py`, `worker.py`, and HMA coordinator - code under `vllm/v1/core/`. -- Connector observability: `stats.py`, `metrics.py`, and failure logging in - `worker.py`. - -## Useful tests and docs - -- Unit tests: `tests/v1/kv_connector/unit/test_nixl_connector.py`. 
-- HMA/hybrid tests: `tests/v1/kv_connector/unit/test_nixl_connector_hma.py`. -- TP mapping tests: `tests/v1/kv_connector/unit/test_tp_mapping.py`. -- Integration examples: `tests/v1/kv_connector/nixl_integration/`. -- User docs: `docs/features/nixl_connector_usage.md` and - `docs/features/nixl_connector_compatibility.md`. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index bbc692c680b3..ff60c0b912fe 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -82,10 +82,6 @@ logger = init_logger(__name__) -# Diagnostic flag for descriptor debugging. Set to True to enable verbose -# logging of KV cache registration, descriptor building, and transfer data. -_DESC_DEBUG = True - class NixlConnectorWorker: """Implementation of Worker side methods""" @@ -105,25 +101,6 @@ def _compute_desc_ids( if block_size_ratio is not None: num_blocks = int(num_blocks * block_size_ratio) num_fa_descs = num_fa_regions * num_blocks - - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] _compute_desc_ids: num_fa_regions=%s, " - "num_ssm_regions=%s, dst_num_blocks=%s, block_size_ratio=%s, " - "physical_blocks_per_logical=%s => num_blocks=%s, " - "num_fa_descs=%s (FA/Mamba boundary), " - "block_ids_lengths=%s, group_spec_types=%s", - num_fa_regions, - num_ssm_regions, - dst_num_blocks, - block_size_ratio, - physical_blocks_per_logical, - num_blocks, - num_fa_descs, - [len(g) for g in block_ids], - [t.__name__ for t in self._group_spec_types], - ) - # All-attention fast path: single vectorized broadcast. 
if num_ssm_regions == 0: # NOTE (NickLucche) With HMA, every kv group has the same number of layers @@ -170,15 +147,6 @@ def _compute_desc_ids( ) result = np.concatenate(all_descs) - if _DESC_DEBUG: - sample = result[:10].tolist() if len(result) > 10 else result.tolist() - logger.warning( - "[DESC-DEBUG] _compute_desc_ids result: total=%d, " - "per_group_counts=%s, sample_ids=%s", - len(result), - [len(d) for d in all_descs], - sample, - ) return result def _build_local_splits_from_plan( @@ -900,31 +868,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if isinstance(layer_spec, UniformTypeKVCacheSpecs): # MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs layer_spec = layer_spec.kv_cache_specs[layer_name] - if _DESC_DEBUG: - cache_type = type(cache_or_caches).__name__ - if isinstance(cache_or_caches, torch.Tensor): - cache_info = f"Tensor(shape={tuple(cache_or_caches.shape)}, dtype={cache_or_caches.dtype})" - elif isinstance(cache_or_caches, (tuple, list)): - parts = [] - for c in cache_or_caches: - if isinstance(c, torch.Tensor): - parts.append( - f"Tensor(shape={tuple(c.shape)}, ptr=0x{c.data_ptr():x})" - ) - else: - parts.append(type(c).__name__) - cache_info = f"({', '.join(parts)})" - else: - cache_info = cache_type - logger.warning( - "[DESC-DEBUG] register layer=%s: spec=%s, is_ssm=%s, " - "cache_type=%s, cache_info=%s", - layer_name, - type(layer_spec).__name__, - isinstance(layer_spec, MambaSpec), - cache_type, - cache_info, - ) cache_list = self.transfer_topo.get_transfer_cache_regions( cache_or_caches, layer_spec ) @@ -974,39 +917,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # regions. 
if is_attn: self._attn_block_len[idx] = physical_page_size - if _DESC_DEBUG: - ssm_stride = self.block_len_per_layer[idx] - logger.warning( - "[DESC-DEBUG] DUAL-PURPOSE region %d: " - "ssm_block_len=%d, attn_block_len=%d, " - "stride_mismatch=%s", - idx, - ssm_stride, - physical_page_size, - ssm_stride != physical_page_size, - ) - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] SKIP layer=%s: base=0x%x already seen " - "(spec=%s, is_ssm=%s, is_attn=%s)", - layer_name, - base_addr, - type(layer_spec).__name__, - is_ssm, - is_attn, - ) continue - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] REGISTER layer=%s: base=0x%x shape=%s " - "spec=%s is_ssm=%s page_size=%d", - layer_name, - base_addr, - tuple(cache.shape), - type(layer_spec).__name__, - isinstance(layer_spec, MambaSpec), - physical_page_size, - ) seen_base_addresses.append(base_addr) is_ssm = isinstance(layer_spec, MambaSpec) is_attn = not is_ssm @@ -1053,36 +964,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): "Different block lengths collected: %s", set(self.block_len_per_layer) ) assert len(self.block_len_per_layer) == len(seen_base_addresses) - - if _DESC_DEBUG: - dual_purpose = [ - i - for i in range(len(self._is_ssm_region)) - if i < len(self._is_attn_region) - and self._is_ssm_region[i] - and self._is_attn_region[i] - ] - logger.warning( - "[DESC-DEBUG] register_kv_caches summary: " - "num_unique_regions=%d, is_ssm_region=%s, " - "block_len_per_layer=%s, attn_block_len=%s, " - "dual_purpose_regions=%s, group_spec_types=%s, " - "has_mamba=%s, use_mla=%s, " - "physical_blocks_per_logical=%s, logical_num_blocks=%s, " - "num_blocks=%s", - len(seen_base_addresses), - self._is_ssm_region, - self.block_len_per_layer, - self._attn_block_len, - dual_purpose, - [t.__name__ for t in self._group_spec_types], - self._has_mamba, - self.use_mla, - self._physical_blocks_per_logical_kv_block, - self._logical_num_blocks, - self.num_blocks, - ) - # Expand NIXL registration size for 
dual-purpose regions. KDA # registers first with its stride (542720 bytes/block), but the # physical allocation uses max(KDA, MLA) = 552960 bytes/block. @@ -1104,15 +985,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if old_size != new_size: base, _, device, label = caches_data[idx] caches_data[idx] = (base, new_size, device, label) - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] DUAL-PURPOSE registration size: " - "region %d, original=%d, expanded=%d", - idx, - old_size, - new_size, - ) - self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses # FA regions count only attention layers. SSM/Mamba regions are # served by Mamba descriptors, not FA descriptors, and must be @@ -1123,18 +995,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # HMA dual-purpose regions (both SSM and attn) are counted here. num_attention_base = sum(self._is_attn_region) self.num_regions = num_attention_base - - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] Region counts: total_caches_data=%d, " - "num_ssm_regions=%d, num_attention_base=%d, " - "num_regions_before_split=%d", - len(caches_data), - sum(self._is_ssm_region), - num_attention_base, - self.num_regions, - ) - if self.transfer_topo.virtually_split_kv_in_blocks: # NOTE (NickLucche) When FlashInfer is used, memory is registered # with joint KV for each block. This minimizes the overhead in @@ -1228,28 +1088,7 @@ def _build_mamba_local( # outside their registered memory, causing prepXferDlist failures. # Dual-purpose regions (HMA) get both FA and Mamba descs. 
if not self._is_ssm_region[i]: - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] _build_mamba_local: SKIP region %d " - "(not SSM), base=0x%x", - i, - base_addr, - ) continue - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] _build_mamba_local: BUILD region %d " - "(SSM), base=0x%x, page_stride=%d, conv_offsets=%s, " - "ssm_size=%d, num_blocks=%d", - i, - base_addr, - self.block_len_per_layer[i] - // block_size_ratio - * physical_per_logical, - conv_offsets, - ssm_size, - num_blocks, - ) # Jump one page_size, but ssm page_size may be bigger when kernel # locks block size to a specific value (physical_per_logical scale). page_stride = ( @@ -1304,28 +1143,7 @@ def _build_mamba_remote( # Only build Mamba descriptors for SSM/Mamba regions. # Attention/MLA regions do not contain conv or temporal state. if i < len(self._is_ssm_region) and not self._is_ssm_region[i]: - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] _build_mamba_remote: SKIP region %d " - "(not SSM), remote_base=0x%x", - i, - base_addr, - ) continue - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] _build_mamba_remote: BUILD region %d " - "(SSM), remote_base=0x%x, page_stride=%d, " - "conv_offsets=%s, ssm_read_size=%d, " - "local_offset=%d, num_blocks=%d", - i, - base_addr, - nixl_agent_meta.block_lens[i] * remote_physical_per_logical, - conv_offsets, - ssm_read_size, - local_offset, - num_blocks, - ) page_stride = nixl_agent_meta.block_lens[i] * remote_physical_per_logical for off, sz in conv_offsets: for blk in range(num_blocks): @@ -1355,12 +1173,6 @@ def _build_fa_local( # SSM/Mamba regions are served by Mamba descriptors. # Dual-purpose regions (HMA) get both FA and Mamba descs. if i < len(self._is_attn_region) and not self._is_attn_region[i]: - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] _build_fa_local: SKIP region %d (not attn), base=0x%x", - i, - base_addr, - ) continue # For dual-purpose regions (HMA), block_len_per_layer stores the # KDA/SSM stride. 
FA descriptors must use the attention spec's @@ -1373,14 +1185,6 @@ def _build_fa_local( kv_block_len = attn_stride kv_block_len = kv_block_len // block_size_ratio page_stride = attn_stride // block_size_ratio - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] _build_fa_local: DUAL-PURPOSE region %d, " - "using attn_stride=%d (ssm_stride=%d)", - i, - page_stride, - self.block_len_per_layer[i] // block_size_ratio, - ) else: kv_block_len = ( self.get_backend_aware_kv_block_len( @@ -1389,18 +1193,6 @@ def _build_fa_local( // block_size_ratio ) page_stride = self.block_len_per_layer[i] // block_size_ratio - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] _build_fa_local: BUILD region %d " - "(attn), base=0x%x, kv_block_len=%d, page_stride=%d, " - "num_blocks=%d, virtual_split=%s", - i, - base_addr, - kv_block_len, - page_stride, - num_blocks, - self.transfer_topo.virtually_split_kv_in_blocks, - ) for block_id in range(num_blocks): block_offset = block_id * page_stride addr = base_addr + block_offset @@ -1442,30 +1234,7 @@ def _build_fa_remote( # SSM/Mamba regions are served by Mamba descriptors. # Dual-purpose regions (HMA) get both FA and Mamba descs. if i < len(self._is_attn_region) and not self._is_attn_region[i]: - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] _build_fa_remote: SKIP region %d " - "(not attn), remote_base=0x%x", - i, - base_addr, - ) continue - if _DESC_DEBUG: - local_block_len = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=True, mamba_view=False - ) - logger.warning( - "[DESC-DEBUG] _build_fa_remote: BUILD region %d " - "(attn), remote_base=0x%x, local_block_len=%d, " - "remote_block_len=%d, num_blocks=%d, " - "num_attn_reads=%d", - i, - base_addr, - local_block_len, - nixl_agent_meta.block_lens[i], - num_blocks, - num_attn_reads, - ) # Read our whole local region size from remote.. # For dual-purpose regions, use the attention spec's stride instead # of block_len_per_layer (which stores KDA's stride). 
@@ -1475,14 +1244,6 @@ def _build_fa_remote( local_block_len = attn_stride // 2 else: local_block_len = attn_stride - if _DESC_DEBUG: - logger.warning( - "[DESC-DEBUG] _build_fa_remote: DUAL-PURPOSE region %d, " - "using attn_stride=%d (remote_block_lens=%d)", - i, - attn_stride, - nixl_agent_meta.block_lens[i], - ) else: local_block_len = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False @@ -1553,7 +1314,6 @@ def register_local_xfer_handler( local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank] blocks_data = self._build_fa_local(local_base_addresses, block_size_ratio) - fa_desc_count = len(blocks_data) logger.debug( "Created %s blocks for src engine %s and rank %s on device id %s", len(blocks_data), @@ -1573,44 +1333,6 @@ def register_local_xfer_handler( blocks_data.extend( self._build_mamba_local(local_base_addresses, block_size_ratio) ) - - if _DESC_DEBUG: - mamba_desc_count = len(blocks_data) - fa_desc_count - logger.warning( - "[DESC-DEBUG] register_local_xfer_handler: " - "fa_desc_count=%d, mamba_desc_count=%d, total=%d, " - "num_descs(boundary)=%d, num_regions=%d, num_blocks=%d, " - "is_ssm_region=%s", - fa_desc_count, - mamba_desc_count, - len(blocks_data), - self.num_descs, - self.num_regions, - self.num_blocks, - self._is_ssm_region, - ) - # Spot-check first/last FA desc and first/last Mamba desc - if blocks_data: - logger.warning( - "[DESC-DEBUG] local FA desc[0]: addr=0x%x size=%d dev=%d", - *blocks_data[0], - ) - if fa_desc_count > 1: - logger.warning( - "[DESC-DEBUG] local FA desc[%d]: addr=0x%x size=%d dev=%d", - fa_desc_count - 1, - *blocks_data[fa_desc_count - 1], - ) - if mamba_desc_count > 0: - logger.warning( - "[DESC-DEBUG] local Mamba desc[0]: addr=0x%x size=%d dev=%d", - *blocks_data[fa_desc_count], - ) - logger.warning( - "[DESC-DEBUG] local Mamba desc[last]: addr=0x%x size=%d dev=%d", - *blocks_data[-1], - ) - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, 
self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data @@ -1793,7 +1515,6 @@ def add_remote_agent( nixl_agent_meta, block_size_ratio, ) - remote_fa_count = len(blocks_data) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and local rank %s", len(blocks_data), @@ -1814,36 +1535,6 @@ def add_remote_agent( transfer_info, ) ) - - if _DESC_DEBUG: - remote_mamba_count = len(blocks_data) - remote_fa_count - logger.warning( - "[DESC-DEBUG] add_remote_agent: remote_fa_descs=%d, " - "remote_mamba_descs=%d, total_remote=%d, " - "tp_ratio=%s, block_size_ratio=%s, " - "remote_num_blocks=%s, remote_base_addrs=%d, " - "remote_block_lens=%s", - remote_fa_count, - remote_mamba_count, - len(blocks_data), - tp_ratio, - block_size_ratio, - nixl_agent_meta.num_blocks, - len(nixl_agent_meta.kv_caches_base_addr), - nixl_agent_meta.block_lens, - ) - # Spot-check remote descriptors - if blocks_data: - logger.warning( - "[DESC-DEBUG] remote FA desc[0]: addr=0x%x size=%d dev=%d", - *blocks_data[0], - ) - if remote_mamba_count > 0: - logger.warning( - "[DESC-DEBUG] remote Mamba desc[0]: addr=0x%x size=%d dev=%d", - *blocks_data[remote_fa_count], - ) - # Register with NIXL. 
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) self.dst_xfer_side_handles[engine_id][remote_tp_rank] = ( @@ -2183,10 +1874,6 @@ def get_finished(self) -> tuple[set[str], set[str]]: meta = self._recving_metadata.pop(req_id, None) assert meta is not None, f"{req_id} not found in recving_metadata list" - # Spot-check KV tensor values after transfer - if _DESC_DEBUG and meta.remote is not None: - self._spot_check_kv_after_recv(req_id, meta) - # Skip KV sync and post-processing for failed requests if req_id in failed_recv_reqs: logger.warning( @@ -2563,7 +2250,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_xfer_side_handle=remote_xfer_side_handle, ) - if self.use_mla and tp_ratio < 0 and read_specs: + if self.use_mla and tp_ratio < 0 and read_specs and not self._has_mamba: # MLA attention KV is replicated across P ranks. D reads from # one P rank, but other P ranks also hold the same KV and need # to be notified so they can release blocks early (vs. timeout). @@ -2576,12 +2263,11 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): # request" errors because the P-side request tracking dicts # (_reqs_to_send / _reqs_to_process) are only populated on the # decode side. 
- if not self._has_mamba: - notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() - remote_agents = self._remote_agents[meta.remote.engine_id] - for rank_to_notify, agent in remote_agents.items(): - if rank_to_notify != read_specs[0].remote_rank: - self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) + notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() + remote_agents = self._remote_agents[meta.remote.engine_id] + for rank_to_notify, agent in remote_agents.items(): + if rank_to_notify != read_specs[0].remote_rank: + self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) def _read_blocks( self, @@ -2696,38 +2382,6 @@ def _read_blocks( ) assert len(local_block_descs_ids) == len(remote_block_descs_ids) - - if _DESC_DEBUG: - local_sample = local_block_descs_ids[:10].tolist() - remote_sample = remote_block_descs_ids[:10].tolist() - logger.warning( - "[DESC-DEBUG] _read_blocks req=%s: num_desc_pairs=%d, " - "local_block_ids=%s, remote_block_ids=%s, " - "local_desc_sample=%s, remote_desc_sample=%s", - request_id, - len(local_block_descs_ids), - [ids[:3] for ids in local_block_ids], - [ids[:3] for ids in remote_block_ids], - local_sample, - remote_sample, - ) - # Spot-check descriptor content (addr/size) from src/dst blocks - if hasattr(self, "src_blocks_data") and self.src_blocks_data: - for check_idx in [0, -1]: - if check_idx < len(local_block_descs_ids): - did = local_block_descs_ids[check_idx] - if 0 <= did < len(self.src_blocks_data): - addr, sz, dev = self.src_blocks_data[did] - logger.warning( - "[DESC-DEBUG] local_desc[%d]=%d => " - "src_blocks_data: addr=0x%x size=%d dev=%d", - check_idx, - did, - addr, - sz, - dev, - ) - # Prepare transfer with Nixl. 
handle = None try: @@ -2757,159 +2411,6 @@ def _read_blocks( ) self._handle_failed_transfer(request_id, handle) - # Counter for spot-check rate limiting - _spot_check_count: int = 0 - - def _spot_check_kv_after_recv(self, req_id: str, meta: ReqMeta) -> None: - """Spot-check KV tensor values after a transfer completes.""" - NixlConnectorWorker._spot_check_count += 1 - if NixlConnectorWorker._spot_check_count > 5: - return # Only check the first few transfers - - num_kv_caches = len(self.device_kv_caches) - logger.warning( - "[DESC-DEBUG] _spot_check_kv_after_recv req=%s " - "num_kv_caches=%d local_block_ids=%s", - req_id, - num_kv_caches, - [ids[:3] for ids in meta.local_physical_block_ids], - ) - - # Dump ALL keys with types and shapes so we can find MLA entries - all_keys_info = [] - for k, v in self.device_kv_caches.items(): - if isinstance(v, torch.Tensor): - all_keys_info.append(f"{k}=Tensor{tuple(v.shape)}") - elif isinstance(v, (tuple, list)): - parts = [] - for c in v: - if isinstance(c, torch.Tensor): - parts.append(f"T{tuple(c.shape)}") - else: - parts.append(type(c).__name__) - all_keys_info.append(f"{k}=({','.join(parts)})") - else: - all_keys_info.append(f"{k}={type(v).__name__}") - logger.warning("[DESC-DEBUG] device_kv_caches keys: %s", all_keys_info) - - # Memory-efficient spot check: only examine transferred blocks, - # not the full tensor (which can be >100 GB and cause OOM). - block_ids = meta.local_physical_block_ids - - # Collect unique block IDs to check (across all groups) - bids_to_check: list[int] = [] - for bids in block_ids[:4]: - for bid in bids[:2]: - if bid not in bids_to_check: - bids_to_check.append(bid) - - # Per-group block IDs: for correct spot-checking, each layer type - # should only check its own group's blocks. 
- # Groups: MambaSpec×N + MLAAttentionSpec×1 - mamba_group_bids: list[int] = [] - attn_group_bids: list[int] = [] - for gi, bids in enumerate(block_ids): - if gi < len(self._group_spec_types) and _is_attention_spec( - self._group_spec_types[gi] - ): - attn_group_bids.extend(bids[:2]) - else: - mamba_group_bids.extend(bids[:2]) - - # Helper to spot-check one layer's transferred blocks - def _check_layer( - label: str, - name: str, - t: torch.Tensor, - group_bids: list[int] | None = None, - ) -> None: - check_bids = group_bids if group_bids else bids_to_check[:4] - block_stats = "" - for bid in check_bids[:4]: - if bid < t.shape[0]: - blk = t[bid].float() - block_stats += ( - f" blk[{bid}]:mean={blk.mean():.6f}" - f",std={blk.std():.6f}" - f",nan={torch.isnan(blk).any().item()}" - f",inf={torch.isinf(blk).any().item()}" - f",f4={blk.flatten()[:4].tolist()}" - ) - logger.warning( - "[DESC-DEBUG] %s layer=%s shape=%s dtype=%s%s", - label, - name, - tuple(t.shape), - t.dtype, - block_stats, - ) - - # Check first 2 KDA layers (conv_state bf16 + ssm_state fp32 from tuple) - checked = 0 - for layer_name, cache_val in list(self.device_kv_caches.items()): - if checked >= 2: - break - if isinstance(cache_val, (tuple, list)): - t = cache_val[0] if len(cache_val) > 0 else None - if t is None or not isinstance(t, torch.Tensor): - continue - if ".attn" in layer_name: - continue # skip MLA entries here - # Check conv_state (bf16) — index 0 - _check_layer("SSM-conv", layer_name, t, mamba_group_bids) - # Check ssm_state (fp32) — index 1 - if len(cache_val) > 1: - ssm_t = cache_val[1] - if isinstance(ssm_t, torch.Tensor): - _check_layer( - "SSM-ssm(fp32)", layer_name, ssm_t, mamba_group_bids - ) - # Verify ssm_state view alignment: data_ptr should be - # offset past conv_state within the shared tensor. 
- conv_ptr = t.data_ptr() - ssm_ptr = ssm_t.data_ptr() - ptr_diff = ssm_ptr - conv_ptr - if self._mamba_ssm_size is not None: - expected_conv_bytes = self._mamba_ssm_size[0] - logger.warning( - "[DESC-DEBUG] SSM alignment: " - "conv_ptr=0x%x, ssm_ptr=0x%x, " - "ptr_diff=%d, expected_conv=%d, " - "match=%s, ssm_stride=%s", - conv_ptr, - ssm_ptr, - ptr_diff, - expected_conv_bytes, - ptr_diff == expected_conv_bytes, - ssm_t.stride(), - ) - checked += 1 - - # CRITICAL: Check ALL MLA layers (keys containing ".self_attn.attn") - # Use attn_group_bids so we check the MLA group's actual blocks, - # not the KDA group's blocks. - mla_count = 0 - for layer_name, cache_val in self.device_kv_caches.items(): - if ".attn" not in layer_name: - continue - if isinstance(cache_val, torch.Tensor): - _check_layer("MLA", layer_name, cache_val, attn_group_bids) - mla_count += 1 - elif isinstance(cache_val, (tuple, list)): - for i, c in enumerate(cache_val): - if isinstance(c, torch.Tensor): - _check_layer(f"MLA[{i}]", layer_name, c, attn_group_bids) - mla_count += 1 - - if checked == 0 and mla_count == 0: - logger.warning( - "[DESC-DEBUG] No torch.Tensor found in device_kv_caches. 
Types: %s", - { - k: type(v).__name__ - for k, v in list(self.device_kv_caches.items())[:4] - }, - ) - def get_mapped_blocks( self, block_ids: np.ndarray, block_size_ratio: int ) -> np.ndarray: From 235643820f53ff36e5a7ef119ef00c1c4dd6d9fd Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 12:38:50 +0800 Subject: [PATCH 22/36] resupport heter TP Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 82 ++++++++++++++----- 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index ff60c0b912fe..b8cc43458498 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -840,7 +840,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): caches_data = [] # With hybrid allocator, layers can share a kv cache tensor - seen_base_addresses = [] + seen_base_addresses: list[int] = [] # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -1070,10 +1070,8 @@ def _build_mamba_local( ) -> list[tuple[int, int, int]]: """Build 4 desc regions (x, B, C, ssm) per layer for local mamba blocks, enabling the 3-read transfer with DS conv layout.""" - assert block_size_ratio == 1, ( - "Mamba 3-read transfer with block_size_ratio != 1 is not tested. " - f"Got block_size_ratio={block_size_ratio}." - ) + # block_size_ratio is not used here because local descriptors always + # use local strides. Only remote descriptors need ratio scaling. assert self._conv_decomp is not None conv_offsets = self._conv_decomp.local_conv_offsets conv_size, ssm_size = self._mamba_ssm_size @@ -1177,14 +1175,15 @@ def _build_fa_local( # For dual-purpose regions (HMA), block_len_per_layer stores the # KDA/SSM stride. 
FA descriptors must use the attention spec's # stride instead so they address MLA data correctly. + # MLA stride is TP-independent (num_kv_heads=1, head_dim + # constant), so we skip block_size_ratio scaling entirely. if i in self._attn_block_len: attn_stride = self._attn_block_len[i] if self.transfer_topo.virtually_split_kv_in_blocks: kv_block_len = attn_stride // 2 else: kv_block_len = attn_stride - kv_block_len = kv_block_len // block_size_ratio - page_stride = attn_stride // block_size_ratio + page_stride = attn_stride else: kv_block_len = ( self.get_backend_aware_kv_block_len( @@ -1238,29 +1237,30 @@ def _build_fa_remote( # Read our whole local region size from remote.. # For dual-purpose regions, use the attention spec's stride instead # of block_len_per_layer (which stores KDA's stride). + # MLA stride is TP-independent, so skip block_size_ratio scaling. if i in self._attn_block_len: attn_stride = self._attn_block_len[i] if self.transfer_topo.virtually_split_kv_in_blocks: local_block_len = attn_stride // 2 else: local_block_len = attn_stride + remote_kv_block_len = local_block_len else: local_block_len = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False ) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - # ..using remote kv_block_len as transfer unit - local_block_len = remote_kv_block_len + remote_kv_block_len = local_block_len // block_size_ratio + if block_size_ratio > 1: + # ..using remote kv_block_len as transfer unit + local_block_len = remote_kv_block_len local_block_len = local_block_len // num_attn_reads rank_offset = plan.rank_offset_factor * remote_kv_block_len # For dual-purpose regions, use the attention stride as page size # so descriptors step through remote memory at MLA's stride. - # NOTE: This assumes homogeneous TP where both sides have the same - # attention spec stride. For heterogeneous TP, the remote's attn - # stride must be propagated via metadata. 
+ # MLA stride is TP-independent (num_kv_heads=1, head_dim constant), + # so the local _attn_block_len equals the remote's MLA stride. if i in self._attn_block_len: page_size = self._attn_block_len[i] else: @@ -1597,10 +1597,17 @@ def _validate_remote_agent_handshake( "Disable prefix caching with --no-enable-prefix-caching." ) - if self._is_hma_required: - assert block_size_ratio == 1, ( - "HMA does not support different remote block size yet" - ) + if ( + self._is_hma_required + and block_size_ratio != 1 + and not (self.use_mla and self._has_mamba) + ): + # For hybrid MLA+SSM models, block_size_ratio reflects SSM + # dimension scaling across different TP sizes. MLA attention + # stride is TP-independent (replicated KV), so FA descriptors + # for dual-purpose regions are safe. SSM descriptors use their + # own hetero-TP-aware remote_conv_offsets. + raise AssertionError("HMA does not support different remote block size yet") kv_cache_layout = ( self.kv_cache_layout if not self.use_host_buffer @@ -1660,9 +1667,42 @@ def _validate_remote_agent_handshake( remote_block_len = nixl_agent_meta.block_lens[0] if self.use_mla or self.transfer_topo.is_kv_replicated(remote_engine_id): # With replicated KV cache, only the number of blocks can differ. - # TODO (ZhanqiuHu): For mamba models, validate FA and mamba - # block_lens separately. - if not self._has_mamba: + if self._has_mamba and self._is_hma_required: + # Hybrid MLA+SSM with HMA: dual-purpose regions have both + # SSM and attention block lengths. Validate separately. + for i in range(len(self.block_len_per_layer)): + is_ssm = i < len(self._is_ssm_region) and self._is_ssm_region[i] + is_attn = i < len(self._is_attn_region) and self._is_attn_region[i] + remote_bl = nixl_agent_meta.block_lens[i] + if is_ssm and not is_attn: + # Pure SSM region: stride scales with TP. 
+ assert ( + self.block_len_per_layer[i] // block_size_ratio == remote_bl + ), ( + f"SSM region {i} block_len mismatch: " + f"local={self.block_len_per_layer[i]} // " + f"ratio={block_size_ratio} != remote={remote_bl}" + ) + elif is_attn and not is_ssm: + # Pure attention region: MLA is replicated (TP-independent). + assert self.block_len_per_layer[i] == remote_bl, ( + f"Attention region {i} block_len mismatch: " + f"local={self.block_len_per_layer[i]} != remote={remote_bl}" + ) + elif is_ssm and is_attn: + # Dual-purpose region: block_len_per_layer stores + # SSM stride (may differ with TP). MLA stride is + # stored in _attn_block_len and is TP-independent. + # Remote block_lens stores the remote's SSM stride + # (which it registered first, like local). + assert ( + self.block_len_per_layer[i] // block_size_ratio == remote_bl + ), ( + f"Dual-purpose region {i} SSM stride mismatch: " + f"local={self.block_len_per_layer[i]} // " + f"ratio={block_size_ratio} != remote={remote_bl}" + ) + else: for i in range(len(self.block_len_per_layer)): assert ( self.block_len_per_layer[i] // block_size_ratio From c3b19e35fb8a8834c0c443bdcdd4bff6efee4af7 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 13:01:03 +0800 Subject: [PATCH 23/36] fix P4D2 Signed-off-by: JaredforReal --- .../kv_transfer/kv_connector/v1/nixl/worker.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index b8cc43458498..312f0470711d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1689,12 +1689,11 @@ def _validate_remote_agent_handshake( f"Attention region {i} block_len mismatch: " f"local={self.block_len_per_layer[i]} != remote={remote_bl}" ) - elif is_ssm and is_attn: + elif is_ssm and is_attn and block_size_ratio != 1: # Dual-purpose region: 
block_len_per_layer stores # SSM stride (may differ with TP). MLA stride is # stored in _attn_block_len and is TP-independent. - # Remote block_lens stores the remote's SSM stride - # (which it registered first, like local). + # Remote block_lens stores the remote's SSM stride. assert ( self.block_len_per_layer[i] // block_size_ratio == remote_bl ), ( @@ -1702,6 +1701,11 @@ def _validate_remote_agent_handshake( f"local={self.block_len_per_layer[i]} // " f"ratio={block_size_ratio} != remote={remote_bl}" ) + # When block_size_ratio == 1 for dual-purpose regions + # (fallback from non-exact byte division), SSM strides + # differ by approximately |tp_ratio| but page padding + # makes the ratio inexact. Skip strict validation — + # remote_conv_offsets handles TP-scaled addressing. else: for i in range(len(self.block_len_per_layer)): assert ( From ded791c687eab88c0e7c92f87c648e2e343d2064 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 14:48:13 +0800 Subject: [PATCH 24/36] clean up Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 62 ++++++++----------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 312f0470711d..8f4ea7830243 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -965,10 +965,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) assert len(self.block_len_per_layer) == len(seen_base_addresses) # Expand NIXL registration size for dual-purpose regions. KDA - # registers first with its stride (542720 bytes/block), but the - # physical allocation uses max(KDA, MLA) = 552960 bytes/block. - # Without this, the registered region is too small for FA descriptors - # that address MLA data at the larger stride. 
+ # registers first with its stride, but the physical allocation + # uses max(KDA, MLA). Without this, the registered region is + # too small for FA descriptors that address MLA data. for idx in range(len(caches_data)): if ( idx < len(self._is_ssm_region) @@ -1172,13 +1171,12 @@ def _build_fa_local( # Dual-purpose regions (HMA) get both FA and Mamba descs. if i < len(self._is_attn_region) and not self._is_attn_region[i]: continue - # For dual-purpose regions (HMA), block_len_per_layer stores the - # KDA/SSM stride. FA descriptors must use the attention spec's - # stride instead so they address MLA data correctly. - # MLA stride is TP-independent (num_kv_heads=1, head_dim - # constant), so we skip block_size_ratio scaling entirely. - if i in self._attn_block_len: - attn_stride = self._attn_block_len[i] + # For dual-purpose HMA regions, block_len_per_layer stores the + # KDA/SSM stride; FA descriptors must use the attention spec's + # stride so they address MLA data correctly. MLA stride is + # TP-independent (num_kv_heads=1), so skip block_size_ratio. + attn_stride = self._attn_block_len.get(i) + if attn_stride is not None: if self.transfer_topo.virtually_split_kv_in_blocks: kv_block_len = attn_stride // 2 else: @@ -1201,8 +1199,8 @@ def _build_fa_local( # Separate and interleave K/V regions to maintain the same # descs ordering. This is needed for selecting contiguous heads # when split across TP ranks. - if i in self._attn_block_len: - second_split = self._attn_block_len[i] // 2 + if attn_stride is not None: + second_split = attn_stride // 2 else: second_split = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=False, mamba_view=False @@ -1234,58 +1232,48 @@ def _build_fa_remote( # Dual-purpose regions (HMA) get both FA and Mamba descs. if i < len(self._is_attn_region) and not self._is_attn_region[i]: continue - # Read our whole local region size from remote.. 
- # For dual-purpose regions, use the attention spec's stride instead - # of block_len_per_layer (which stores KDA's stride). - # MLA stride is TP-independent, so skip block_size_ratio scaling. - if i in self._attn_block_len: - attn_stride = self._attn_block_len[i] + # For dual-purpose HMA regions, use the attention spec's stride + # instead of block_len_per_layer (which stores KDA's stride). + # MLA stride is TP-independent (num_kv_heads=1), so skip + # block_size_ratio scaling — local and remote MLA strides match. + attn_stride = self._attn_block_len.get(i) + if attn_stride is not None: if self.transfer_topo.virtually_split_kv_in_blocks: local_block_len = attn_stride // 2 else: local_block_len = attn_stride remote_kv_block_len = local_block_len + page_size = attn_stride else: local_block_len = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False ) remote_kv_block_len = local_block_len // block_size_ratio if block_size_ratio > 1: - # ..using remote kv_block_len as transfer unit local_block_len = remote_kv_block_len + page_size = nixl_agent_meta.block_lens[i] local_block_len = local_block_len // num_attn_reads rank_offset = plan.rank_offset_factor * remote_kv_block_len - # For dual-purpose regions, use the attention stride as page size - # so descriptors step through remote memory at MLA's stride. - # MLA stride is TP-independent (num_kv_heads=1, head_dim constant), - # so the local _attn_block_len equals the remote's MLA stride. - if i in self._attn_block_len: - page_size = self._attn_block_len[i] - else: - page_size = nixl_agent_meta.block_lens[i] for block_id in range(num_blocks): block_offset = block_id * page_size - # For each block, grab the kv heads chunk belonging to current local - # tp rank of size local_block_len. + # For each block, grab the kv heads chunk belonging to current + # local tp rank of size local_block_len. 
addr = base_addr + block_offset + rank_offset result.append((addr, local_block_len, nixl_agent_meta.device_id)) if self.transfer_topo.virtually_split_kv_in_blocks: # With FlashInfer index V separately to allow head splitting. - if i in self._attn_block_len: - second_split = self._attn_block_len[i] // 2 + if attn_stride is not None: + second_split = attn_stride // 2 + v_stride = attn_stride else: second_split = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=False, mamba_view=False ) + v_stride = nixl_agent_meta.block_lens[i] second_split = second_split // num_attn_reads - v_stride = ( - self._attn_block_len[i] - if i in self._attn_block_len - else nixl_agent_meta.block_lens[i] - ) for block_id in range(num_blocks): block_offset = block_id * page_size addr = base_addr + block_offset + rank_offset From 8115e25ac5c07bda6181c17a4aa2ddfc84f5bab5 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 15:01:30 +0800 Subject: [PATCH 25/36] add more unit test Signed-off-by: JaredforReal --- .../unit/test_nixl_connector_hma.py | 233 +++++++++++++++++- 1 file changed, 232 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 6e399db7b140..a8c9b866d6de 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -3,7 +3,7 @@ """Unit tests for NixlConnectorScheduler with HMA and Mamba N-1 prefill.""" import gc -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest import torch @@ -1081,3 +1081,234 @@ def test_logical_to_remote_kernel_block_ids( assert list(result) == expected_kernel_block_ids, ( f"Expected {expected_kernel_block_ids}, got {result}" ) + + +# ── Dual-purpose HMA region tests ─────────────────────────────────────── + + +def _make_mock_worker_for_desc(**overrides): + """Build a mock NixlConnectorWorker with attrs for descriptor 
tests.""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import ( + NixlConnectorWorker, + ) + + worker = object.__new__(NixlConnectorWorker) + defaults = { + "num_blocks": 4, + "_logical_num_blocks": 4, + "_physical_blocks_per_logical_kv_block": 1, + "device_id": 0, + "block_len_per_layer": [], + "_is_ssm_region": [], + "_is_attn_region": [], + "_attn_block_len": {}, + "_has_mamba": True, + "num_regions": 0, + "_group_spec_types": (), + "_conv_decomp": None, + "_mamba_ssm_size": (0, 0), + } + for k, v in defaults.items(): + setattr(worker, k, overrides.get(k, v)) + + worker.transfer_topo = MagicMock() + worker.transfer_topo.virtually_split_kv_in_blocks = False + + return worker + + +def _make_mock_nixl_meta( + base_addrs, block_lens, num_blocks=4, device_id=0, ssm_sizes=(96, 64) +): + """Build a mock NixlAgentMetadata.""" + meta = MagicMock() + meta.kv_caches_base_addr = base_addrs + meta.block_lens = block_lens + meta.num_blocks = num_blocks + meta.device_id = device_id + meta.ssm_sizes = ssm_sizes + meta.physical_blocks_per_logical_kv_block = 1 + return meta + + +class TestBuildFaLocalDualPurpose: + """Tests for _build_fa_local with dual-purpose HMA regions.""" + + @pytest.mark.cpu_test + def test_dual_purpose_uses_mla_stride(self): + """Dual-purpose regions use _attn_block_len (MLA) stride, not KDA.""" + worker = _make_mock_worker_for_desc( + block_len_per_layer=[200], # KDA stride + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, # MLA stride + num_blocks=3, + ) + result = worker._build_fa_local([0x2000], block_size_ratio=1) + + assert len(result) == 3 + for i, (addr, size, dev) in enumerate(result): + assert addr == 0x2000 + i * 256 # MLA stride, not KDA 200 + assert size == 256 + + @pytest.mark.cpu_test + def test_skips_ssm_only_regions(self): + """Pure SSM regions (not attn) are skipped entirely.""" + worker = _make_mock_worker_for_desc( + block_len_per_layer=[200, 200], + _is_ssm_region=[True, True], + 
_is_attn_region=[True, False], # region 1 is SSM-only + _attn_block_len={0: 256}, + num_blocks=2, + ) + result = worker._build_fa_local([0x1000, 0x2000], block_size_ratio=1) + + assert len(result) == 2 + # Only region 0 (dual-purpose) generates FA descs + assert all(a < 0x2000 for a, _, _ in result) + + @pytest.mark.cpu_test + def test_no_block_size_ratio_for_dual_purpose(self): + """Dual-purpose: block_size_ratio does NOT scale MLA stride.""" + worker = _make_mock_worker_for_desc( + block_len_per_layer=[200], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, + num_blocks=2, + ) + result = worker._build_fa_local([0x1000], block_size_ratio=2) + + # MLA stride 256 unscaled (not 256//2=128) + assert result[0][1] == 256 + assert result[1][0] == 0x1000 + 256 + + @pytest.mark.cpu_test + def test_mixed_regions(self): + """Mix of dual-purpose, pure SSM, and pure attention regions.""" + worker = _make_mock_worker_for_desc( + block_len_per_layer=[200, 128, 300], + _is_ssm_region=[True, True, False], + _is_attn_region=[True, False, True], + _attn_block_len={0: 256}, # only region 0 is dual-purpose + num_blocks=2, + ) + result = worker._build_fa_local([0x1000, 0x2000, 0x3000], block_size_ratio=1) + + # Region 0 (dual-purpose, MLA stride 256) + Region 2 (pure attn, 300) + assert len(result) == 4 + assert result[0] == (0x1000, 256, 0) + assert result[1] == (0x1000 + 256, 256, 0) + assert result[2] == (0x3000, 300, 0) + assert result[3] == (0x3000 + 300, 300, 0) + + +class TestBuildFaRemoteDualPurpose: + """Tests for _build_fa_remote with dual-purpose HMA regions.""" + + @pytest.mark.cpu_test + def test_dual_purpose_uses_local_mla_stride(self): + """Remote FA descs for dual-purpose use local MLA stride.""" + worker = _make_mock_worker_for_desc( + block_len_per_layer=[200], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, + ) + plan = MagicMock() + plan.source_ranks_per_group = (MagicMock(),) + 
plan.source_ranks_per_group[0].__len__ = MagicMock(return_value=1) + plan.rank_offset_factor = 0 + meta = _make_mock_nixl_meta(base_addrs=[0x5000], block_lens=[200], num_blocks=3) + + result = worker._build_fa_remote(plan, meta, block_size_ratio=1) + + assert len(result) == 3 + # page_size = _attn_block_len[0] = 256 (not remote's 200) + assert result[0] == (0x5000, 256, 0) + assert result[1] == (0x5000 + 256, 256, 0) + assert result[2] == (0x5000 + 512, 256, 0) + + @pytest.mark.cpu_test + def test_no_ratio_scaling_for_dual_purpose(self): + """Dual-purpose: block_size_ratio doesn't scale MLA stride.""" + worker = _make_mock_worker_for_desc( + block_len_per_layer=[200], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, + ) + plan = MagicMock() + plan.source_ranks_per_group = (MagicMock(),) + plan.source_ranks_per_group[0].__len__ = MagicMock(return_value=1) + plan.rank_offset_factor = 0 + meta = _make_mock_nixl_meta(base_addrs=[0x5000], block_lens=[100], num_blocks=2) + + result = worker._build_fa_remote(plan, meta, block_size_ratio=2) + + for addr, size, dev in result: + assert size == 256 # unscaled MLA stride + + +class TestNumRegionsDualPurpose: + """Tests for num_regions = sum(_is_attn_region).""" + + @pytest.mark.cpu_test + def test_pure_gdn(self): + """Qwen3.5 GDN: all SSM → num_regions = 0.""" + assert sum([False, False, False]) == 0 + + @pytest.mark.cpu_test + def test_pure_mla(self): + """Pure MLA: all attn → num_regions = N.""" + assert sum([True, True, True]) == 3 + + @pytest.mark.cpu_test + def test_kimilinear_dual_purpose(self): + """KimiLinear: 7 dual-purpose + 13 SSM-only → num_regions = 7. + + Old formula (len - sum(_is_ssm_region)) = 20 - 20 = 0 missed + the 7 dual-purpose regions. New formula correctly counts them. 
+ """ + _is_attn = [True] * 7 + [False] * 13 + _is_ssm = [True] * 7 + [True] * 13 + + old_formula = len(_is_attn) - sum(_is_ssm) # 0 (wrong) + new_formula = sum(_is_attn) # 7 (correct) + + assert old_formula == 0 + assert new_formula == 7 + + +class TestNonHMARegression: + """Verify non-HMA models (Qwen3.5 GDN) are unaffected.""" + + @pytest.mark.cpu_test + def test_qwen35_gdn_skips_all_fa(self): + """Qwen3.5 GDN: all SSM, no attn → _build_fa_local produces 0.""" + worker = _make_mock_worker_for_desc( + block_len_per_layer=[128], + _is_ssm_region=[True], + _is_attn_region=[False], + _attn_block_len={}, + num_blocks=4, + ) + result = worker._build_fa_local([0x1000], block_size_ratio=1) + assert len(result) == 0 + + @pytest.mark.cpu_test + def test_pure_mla_no_dual_purpose(self): + """Pure MLA: _attn_block_len empty → standard path.""" + worker = _make_mock_worker_for_desc( + block_len_per_layer=[512], + _is_ssm_region=[False], + _is_attn_region=[True], + _attn_block_len={}, # empty → .get(i) returns None → else + _has_mamba=False, + num_blocks=3, + ) + result = worker._build_fa_local([0x1000], block_size_ratio=1) + + assert len(result) == 3 + for i, (addr, size, dev) in enumerate(result): + assert addr == 0x1000 + i * 512 From abc27a144e7eb1a784aa66c3613ad2aa397b1f3a Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 15:05:19 +0800 Subject: [PATCH 26/36] fix unit test Signed-off-by: JaredforReal --- tests/v1/kv_connector/unit/test_nixl_connector_hma.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index a8c9b866d6de..b792ccec5ff8 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -583,6 +583,9 @@ def _make_mock_worker_for_desc_ids( worker._has_mamba = has_mamba worker._group_spec_types = group_spec_types worker.block_len_per_layer = block_len_per_layer 
or [100] + worker._is_ssm_region = [] + worker._is_attn_region = [] + worker._attn_block_len = {} worker._compute_desc_ids = NixlConnectorWorker._compute_desc_ids.__get__( worker, NixlConnectorWorker ) @@ -1209,11 +1212,14 @@ class TestBuildFaRemoteDualPurpose: @pytest.mark.cpu_test def test_dual_purpose_uses_local_mla_stride(self): """Remote FA descs for dual-purpose use local MLA stride.""" + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + worker = _make_mock_worker_for_desc( block_len_per_layer=[200], _is_ssm_region=[True], _is_attn_region=[True], _attn_block_len={0: 256}, + _group_spec_types=(FullAttentionSpec, MambaSpec), ) plan = MagicMock() plan.source_ranks_per_group = (MagicMock(),) @@ -1232,11 +1238,14 @@ def test_dual_purpose_uses_local_mla_stride(self): @pytest.mark.cpu_test def test_no_ratio_scaling_for_dual_purpose(self): """Dual-purpose: block_size_ratio doesn't scale MLA stride.""" + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + worker = _make_mock_worker_for_desc( block_len_per_layer=[200], _is_ssm_region=[True], _is_attn_region=[True], _attn_block_len={0: 256}, + _group_spec_types=(FullAttentionSpec, MambaSpec), ) plan = MagicMock() plan.source_ranks_per_group = (MagicMock(),) From 5240efcf4c5a035eb4364f58720488452803f7eb Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 15:10:01 +0800 Subject: [PATCH 27/36] fix helper Signed-off-by: JaredforReal --- tests/v1/kv_connector/unit/test_nixl_connector_hma.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index b792ccec5ff8..c09d982a3afc 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -583,8 +583,13 @@ def _make_mock_worker_for_desc_ids( worker._has_mamba = has_mamba worker._group_spec_types = group_spec_types 
worker.block_len_per_layer = block_len_per_layer or [100] - worker._is_ssm_region = [] - worker._is_attn_region = [] + # Derive _is_ssm_region from group_spec_types: one entry per + # unique base-address region. With _cross_layers_blocks each group + # maps to one region, so we use a simple per-group flag. + from vllm.v1.kv_cache_interface import MambaSpec + + worker._is_ssm_region = [issubclass(t, MambaSpec) for t in group_spec_types] + worker._is_attn_region = [not s for s in worker._is_ssm_region] worker._attn_block_len = {} worker._compute_desc_ids = NixlConnectorWorker._compute_desc_ids.__get__( worker, NixlConnectorWorker From 31e960a65c43d3cab32a9c2bdaa5cf8dcb60f930 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 15:13:12 +0800 Subject: [PATCH 28/36] revert utils Signed-off-by: JaredforReal --- .../kv_transfer/kv_connector/utils.py | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index f21b6bb07e82..0ab694b7e73d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -533,24 +533,12 @@ def tp_ratio(self, remote_tp_size: int) -> int: return -(remote_tp_size // self.tp_size) def block_size_ratio(self, remote_block_size: int) -> int: - """Calculate the block size ratio between local and remote. - - Positive when local >= remote (local blocks are larger). - Negative when remote > local (remote blocks are larger). - """ - if self.block_size == remote_block_size: - return 1 - if self.block_size > remote_block_size: - assert self.block_size % remote_block_size == 0, ( - f"Local block size {self.block_size} is not divisible " - f"by remote block size {remote_block_size}." 
- ) - return self.block_size // remote_block_size - assert remote_block_size % self.block_size == 0, ( - f"Remote block size {remote_block_size} is not divisible " - f"by local block size {self.block_size}." + """Calculate the block size ratio between local and remote.""" + assert self.block_size % remote_block_size == 0, ( + f"Local block size {self.block_size} is not divisible " + f"by remote block size {remote_block_size} or vice versa." ) - return -(remote_block_size // self.block_size) + return self.block_size // remote_block_size def is_kv_replicated( self, remote_engine_id: EngineId, remote_pp_rank: int = 0 From 6f73e044323c6a07aba12813e671ac9385328a82 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 18:40:48 +0800 Subject: [PATCH 29/36] fix qwen heter tp Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 51 ++++++++++++++----- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 8f4ea7830243..c6a4c55cc6f6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1234,16 +1234,30 @@ def _build_fa_remote( continue # For dual-purpose HMA regions, use the attention spec's stride # instead of block_len_per_layer (which stores KDA's stride). - # MLA stride is TP-independent (num_kv_heads=1), so skip - # block_size_ratio scaling — local and remote MLA strides match. attn_stride = self._attn_block_len.get(i) if attn_stride is not None: - if self.transfer_topo.virtually_split_kv_in_blocks: - local_block_len = attn_stride // 2 + if self.use_mla: + # MLA stride is TP-independent (num_kv_heads=1), + # so local attn_stride equals remote's MLA stride. 
+ if self.transfer_topo.virtually_split_kv_in_blocks: + local_block_len = attn_stride // 2 + else: + local_block_len = attn_stride + remote_kv_block_len = local_block_len + page_size = attn_stride else: - local_block_len = attn_stride - remote_kv_block_len = local_block_len - page_size = attn_stride + # Standard attention stride is TP-dependent + # (num_kv_heads scales with TP). Use remote's + # block_len for stepping through remote memory and + # apply block_size_ratio for correct sizes. + if self.transfer_topo.virtually_split_kv_in_blocks: + local_block_len = attn_stride // 2 + else: + local_block_len = attn_stride + remote_kv_block_len = local_block_len // block_size_ratio + if block_size_ratio > 1: + local_block_len = remote_kv_block_len + page_size = nixl_agent_meta.block_lens[i] else: local_block_len = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False @@ -1266,8 +1280,14 @@ def _build_fa_remote( if self.transfer_topo.virtually_split_kv_in_blocks: # With FlashInfer index V separately to allow head splitting. if attn_stride is not None: - second_split = attn_stride // 2 - v_stride = attn_stride + if self.use_mla: + second_split = attn_stride // 2 + v_stride = attn_stride + else: + # Standard attention: V size and stride must + # match remote's physical page layout. + second_split = local_block_len + v_stride = page_size else: second_split = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=False, mamba_view=False @@ -1590,11 +1610,14 @@ def _validate_remote_agent_handshake( and block_size_ratio != 1 and not (self.use_mla and self._has_mamba) ): - # For hybrid MLA+SSM models, block_size_ratio reflects SSM - # dimension scaling across different TP sizes. MLA attention - # stride is TP-independent (replicated KV), so FA descriptors - # for dual-purpose regions are safe. SSM descriptors use their - # own hetero-TP-aware remote_conv_offsets. 
+ # For hybrid models with HMA, block_size_ratio != 1 means + # token-level block sizes differ. MLA+SSM models are + # exempted because MLA stride is TP-independent. + # Standard attention + SSM models (e.g. Qwen GDN) are + # handled by _build_fa_remote which applies block_size_ratio + # to compute remote attention dimensions. + # SSM descriptors use their own hetero-TP-aware + # remote_conv_offsets. raise AssertionError("HMA does not support different remote block size yet") kv_cache_layout = ( self.kv_cache_layout From ee1caa1b29d9f69bdb1e27862405fb2171568a5d Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 19:02:16 +0800 Subject: [PATCH 30/36] fix qwen heter tp mismatch Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 48 +++++++++++++++---- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index c6a4c55cc6f6..5a13ffbb857f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1500,11 +1500,26 @@ def add_remote_agent( # we only do this once per remote tp_size (replica-friendly). 
self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] - for handle_data in self._build_local_splits_from_plan( - plan, - self.src_blocks_data, - self.num_descs, + for split_idx, handle_data in enumerate( + self._build_local_splits_from_plan( + plan, + self.src_blocks_data, + self.num_descs, + ) ): + # Debug: log split handle mamba descriptor sizes + if self._has_mamba and abs(tp_ratio) != 1: + fa_count = self.num_descs + mamba_part = handle_data[fa_count:] + logger.info( + "HETEROTP local split[%d]: total=%d, fa=%d, " + "mamba=%d, first_mamba_sizes=%s", + split_idx, + len(handle_data), + fa_count, + len(mamba_part), + [s for _, s, _ in mamba_part[:8]], + ) descs = self.nixl_wrapper.get_xfer_descs( handle_data, self.nixl_memory_type ) @@ -1536,13 +1551,28 @@ def add_remote_agent( engine_id, remote_tp_rank, ) - blocks_data.extend( - self._build_mamba_remote( - nixl_agent_meta, + mamba_remote = self._build_mamba_remote( + nixl_agent_meta, + tp_ratio, + transfer_info, + ) + # Debug: log mamba descriptor sizes for hetero-TP diagnosis + if abs(tp_ratio) != 1 and mamba_remote: + fa_count = len(blocks_data) + logger.info( + "HETEROTP remote descs: fa=%d, mamba=%d, " + "first_mamba_sizes=%s, " + "ssm_sizes_remote=%s, ssm_sizes_local=%s, " + "tp_ratio=%s, block_size_ratio=%s", + fa_count, + len(mamba_remote), + [s for _, s, _ in mamba_remote[:8]], + nixl_agent_meta.ssm_sizes, + self._mamba_ssm_size, tp_ratio, - transfer_info, + block_size_ratio, ) - ) + blocks_data.extend(mamba_remote) # Register with NIXL. 
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) self.dst_xfer_side_handles[engine_id][remote_tp_rank] = ( From fae2b32e1a9e411bf44a0b89880e4052c6786242 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 19:14:00 +0800 Subject: [PATCH 31/36] more Signed-off-by: JaredforReal --- .../kv_transfer/kv_connector/v1/nixl/worker.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 5a13ffbb857f..afcadabcf29d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1248,15 +1248,17 @@ def _build_fa_remote( else: # Standard attention stride is TP-dependent # (num_kv_heads scales with TP). Use remote's - # block_len for stepping through remote memory and - # apply block_size_ratio for correct sizes. + # block_len for stepping through remote memory. + # Do NOT apply block_size_ratio to descriptor SIZE — + # the split handle mechanism (fa_num_splits) already + # reduces sizes for heterogeneous TP. Applying + # block_size_ratio here would double-reduce and cause + # local/remote size mismatches at makeXferReq. 
if self.transfer_topo.virtually_split_kv_in_blocks: local_block_len = attn_stride // 2 else: local_block_len = attn_stride - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - local_block_len = remote_kv_block_len + remote_kv_block_len = local_block_len page_size = nixl_agent_meta.block_lens[i] else: local_block_len = self.get_backend_aware_kv_block_len( From deb070af8d4208051a48421a875b7bf8fb78b548 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 20:16:31 +0800 Subject: [PATCH 32/36] fix Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index afcadabcf29d..696501a28614 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1439,14 +1439,25 @@ def add_remote_agent( # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| - # Compute block_size_ratio from actual byte-per-block values so that - # heterogeneous TP works even when block_size carries byte values - # (e.g. hybrid MLA+GDN models where block_size differs across TP - # configs). _build_fa_remote already uses nixl_agent_meta.block_lens - # directly, so this ratio is only needed for handler registration and - # descriptor-ID computation. + # Compute block_size_ratio. + # + # For dual-purpose HMA regions (MLA+GDN sharing one backing tensor), + # block_len_per_layer stores the KDA/SSM stride which differs from the + # attention stride. The token-level block_size is the same across TP, + # so transfer_topo.block_size_ratio returns 1 even though byte-level + # strides differ. 
Use byte-level comparison only for these models so + # that mamba descriptor building and desc-ID computation see the + # correct ratio. + # + # For standard attention (Qwen, etc.), the byte-level difference comes + # purely from KV head count, which num_attn_reads and the split-handle + # mechanism already handle correctly. Using byte-level ratio here + # would cause double reduction in _build_fa_remote (once by + # block_size_ratio and once by num_attn_reads), producing remote + # descriptors that are half the size of local split descriptors. if ( - self.block_len_per_layer + self._attn_block_len # Only for dual-purpose HMA regions + and self.block_len_per_layer and nixl_agent_meta.block_lens and self.block_len_per_layer[0] != nixl_agent_meta.block_lens[0] ): From 824eb1797f792f6adf1f4ab7a4a6d7fe7a7f79ce Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 23:09:03 +0800 Subject: [PATCH 33/36] use mla Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 63 +++++++------------ 1 file changed, 23 insertions(+), 40 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 696501a28614..516caeca3b8e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1171,11 +1171,13 @@ def _build_fa_local( # Dual-purpose regions (HMA) get both FA and Mamba descs. if i < len(self._is_attn_region) and not self._is_attn_region[i]: continue - # For dual-purpose HMA regions, block_len_per_layer stores the - # KDA/SSM stride; FA descriptors must use the attention spec's - # stride so they address MLA data correctly. MLA stride is - # TP-independent (num_kv_heads=1), so skip block_size_ratio. 
- attn_stride = self._attn_block_len.get(i) + # For dual-purpose HMA regions with MLA (num_kv_heads=1, stride + # is TP-independent), use the attention spec's stride instead of + # block_len_per_layer (which stores KDA/SSM stride). + # For standard attention (e.g. Qwen GQA), stride is TP-dependent, + # so we must fall through to the original path that correctly + # applies block_size_ratio for heterogeneous TP. + attn_stride = self._attn_block_len.get(i) if self.use_mla else None if attn_stride is not None: if self.transfer_topo.virtually_split_kv_in_blocks: kv_block_len = attn_stride // 2 @@ -1232,34 +1234,20 @@ def _build_fa_remote( # Dual-purpose regions (HMA) get both FA and Mamba descs. if i < len(self._is_attn_region) and not self._is_attn_region[i]: continue - # For dual-purpose HMA regions, use the attention spec's stride - # instead of block_len_per_layer (which stores KDA's stride). - attn_stride = self._attn_block_len.get(i) + # For dual-purpose HMA regions with MLA (stride is TP-independent), + # use the attention spec's stride. For standard attention (Qwen GQA), + # stride is TP-dependent so fall through to the original path that + # applies block_size_ratio correctly for heterogeneous TP. + attn_stride = self._attn_block_len.get(i) if self.use_mla else None if attn_stride is not None: - if self.use_mla: - # MLA stride is TP-independent (num_kv_heads=1), - # so local attn_stride equals remote's MLA stride. - if self.transfer_topo.virtually_split_kv_in_blocks: - local_block_len = attn_stride // 2 - else: - local_block_len = attn_stride - remote_kv_block_len = local_block_len - page_size = attn_stride + # MLA stride is TP-independent (num_kv_heads=1), + # so local attn_stride equals remote's MLA stride. + if self.transfer_topo.virtually_split_kv_in_blocks: + local_block_len = attn_stride // 2 else: - # Standard attention stride is TP-dependent - # (num_kv_heads scales with TP). Use remote's - # block_len for stepping through remote memory. 
- # Do NOT apply block_size_ratio to descriptor SIZE — - # the split handle mechanism (fa_num_splits) already - # reduces sizes for heterogeneous TP. Applying - # block_size_ratio here would double-reduce and cause - # local/remote size mismatches at makeXferReq. - if self.transfer_topo.virtually_split_kv_in_blocks: - local_block_len = attn_stride // 2 - else: - local_block_len = attn_stride - remote_kv_block_len = local_block_len - page_size = nixl_agent_meta.block_lens[i] + local_block_len = attn_stride + remote_kv_block_len = local_block_len + page_size = attn_stride else: local_block_len = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=True, mamba_view=False @@ -1282,14 +1270,8 @@ def _build_fa_remote( if self.transfer_topo.virtually_split_kv_in_blocks: # With FlashInfer index V separately to allow head splitting. if attn_stride is not None: - if self.use_mla: - second_split = attn_stride // 2 - v_stride = attn_stride - else: - # Standard attention: V size and stride must - # match remote's physical page layout. - second_split = local_block_len - v_stride = page_size + second_split = attn_stride // 2 + v_stride = attn_stride else: second_split = self.get_backend_aware_kv_block_len( layer_idx=i, first_split=False, mamba_view=False @@ -1456,7 +1438,8 @@ def add_remote_agent( # block_size_ratio and once by num_attn_reads), producing remote # descriptors that are half the size of local split descriptors. 
if ( - self._attn_block_len # Only for dual-purpose HMA regions + self.use_mla # Only for MLA where stride is TP-independent + and self._attn_block_len # Only for dual-purpose HMA regions and self.block_len_per_layer and nixl_agent_meta.block_lens and self.block_len_per_layer[0] != nixl_agent_meta.block_lens[0] From 4abbd1bd26109b04bd6da2d77f7eb01ea030031d Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 23:35:18 +0800 Subject: [PATCH 34/36] clean up Signed-off-by: JaredforReal --- .../kv_connector/v1/nixl/worker.py | 96 +++++-------------- 1 file changed, 23 insertions(+), 73 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 516caeca3b8e..5bd137e9ac8d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -101,6 +101,7 @@ def _compute_desc_ids( if block_size_ratio is not None: num_blocks = int(num_blocks * block_size_ratio) num_fa_descs = num_fa_regions * num_blocks + # All-attention fast path: single vectorized broadcast. if num_ssm_regions == 0: # NOTE (NickLucche) With HMA, every kv group has the same number of layers @@ -146,8 +147,7 @@ def _compute_desc_ids( f"Unknown spec type {self._group_spec_types[i]} at index {i}" ) - result = np.concatenate(all_descs) - return result + return np.concatenate(all_descs) def _build_local_splits_from_plan( self, @@ -919,8 +919,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self._attn_block_len[idx] = physical_page_size continue seen_base_addresses.append(base_addr) - is_ssm = isinstance(layer_spec, MambaSpec) - is_attn = not is_ssm self._is_ssm_region.append(is_ssm) self._is_attn_region.append(is_attn) # Only record non-Mamba page sizes. @@ -1171,12 +1169,8 @@ def _build_fa_local( # Dual-purpose regions (HMA) get both FA and Mamba descs. 
if i < len(self._is_attn_region) and not self._is_attn_region[i]: continue - # For dual-purpose HMA regions with MLA (num_kv_heads=1, stride - # is TP-independent), use the attention spec's stride instead of - # block_len_per_layer (which stores KDA/SSM stride). - # For standard attention (e.g. Qwen GQA), stride is TP-dependent, - # so we must fall through to the original path that correctly - # applies block_size_ratio for heterogeneous TP. + # Dual-purpose HMA regions: use MLA's TP-independent stride. + # Standard attention falls through to block_size_ratio path. attn_stride = self._attn_block_len.get(i) if self.use_mla else None if attn_stride is not None: if self.transfer_topo.virtually_split_kv_in_blocks: @@ -1234,10 +1228,7 @@ def _build_fa_remote( # Dual-purpose regions (HMA) get both FA and Mamba descs. if i < len(self._is_attn_region) and not self._is_attn_region[i]: continue - # For dual-purpose HMA regions with MLA (stride is TP-independent), - # use the attention spec's stride. For standard attention (Qwen GQA), - # stride is TP-dependent so fall through to the original path that - # applies block_size_ratio correctly for heterogeneous TP. + # MLA: TP-independent stride; standard attention: use block_size_ratio. attn_stride = self._attn_block_len.get(i) if self.use_mla else None if attn_stride is not None: # MLA stride is TP-independent (num_kv_heads=1), @@ -1325,6 +1316,7 @@ def register_local_xfer_handler( blocks_data.extend( self._build_mamba_local(local_base_addresses, block_size_ratio) ) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data @@ -1422,21 +1414,10 @@ def add_remote_agent( # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| # Compute block_size_ratio. 
- # - # For dual-purpose HMA regions (MLA+GDN sharing one backing tensor), - # block_len_per_layer stores the KDA/SSM stride which differs from the - # attention stride. The token-level block_size is the same across TP, - # so transfer_topo.block_size_ratio returns 1 even though byte-level - # strides differ. Use byte-level comparison only for these models so - # that mamba descriptor building and desc-ID computation see the - # correct ratio. - # - # For standard attention (Qwen, etc.), the byte-level difference comes - # purely from KV head count, which num_attn_reads and the split-handle - # mechanism already handle correctly. Using byte-level ratio here - # would cause double reduction in _build_fa_remote (once by - # block_size_ratio and once by num_attn_reads), producing remote - # descriptors that are half the size of local split descriptors. + # For dual-purpose HMA regions, byte-level strides differ across TP + # even when token-level block_size is the same. Compute byte-level + # ratio for MLA models. Standard attention models use the existing + # num_attn_reads mechanism instead. if ( self.use_mla # Only for MLA where stride is TP-independent and self._attn_block_len # Only for dual-purpose HMA regions @@ -1496,26 +1477,11 @@ def add_remote_agent( # we only do this once per remote tp_size (replica-friendly). 
self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] - for split_idx, handle_data in enumerate( - self._build_local_splits_from_plan( - plan, - self.src_blocks_data, - self.num_descs, - ) + for handle_data in self._build_local_splits_from_plan( + plan, + self.src_blocks_data, + self.num_descs, ): - # Debug: log split handle mamba descriptor sizes - if self._has_mamba and abs(tp_ratio) != 1: - fa_count = self.num_descs - mamba_part = handle_data[fa_count:] - logger.info( - "HETEROTP local split[%d]: total=%d, fa=%d, " - "mamba=%d, first_mamba_sizes=%s", - split_idx, - len(handle_data), - fa_count, - len(mamba_part), - [s for _, s, _ in mamba_part[:8]], - ) descs = self.nixl_wrapper.get_xfer_descs( handle_data, self.nixl_memory_type ) @@ -1547,28 +1513,13 @@ def add_remote_agent( engine_id, remote_tp_rank, ) - mamba_remote = self._build_mamba_remote( - nixl_agent_meta, - tp_ratio, - transfer_info, - ) - # Debug: log mamba descriptor sizes for hetero-TP diagnosis - if abs(tp_ratio) != 1 and mamba_remote: - fa_count = len(blocks_data) - logger.info( - "HETEROTP remote descs: fa=%d, mamba=%d, " - "first_mamba_sizes=%s, " - "ssm_sizes_remote=%s, ssm_sizes_local=%s, " - "tp_ratio=%s, block_size_ratio=%s", - fa_count, - len(mamba_remote), - [s for _, s, _ in mamba_remote[:8]], - nixl_agent_meta.ssm_sizes, - self._mamba_ssm_size, + blocks_data.extend( + self._build_mamba_remote( + nixl_agent_meta, tp_ratio, - block_size_ratio, + transfer_info, ) - blocks_data.extend(mamba_remote) + ) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) self.dst_xfer_side_handles[engine_id][remote_tp_rank] = ( @@ -1687,7 +1638,8 @@ def _validate_remote_agent_handshake( # Heterogeneous TP requires head-splitting, which only works with # HND layout. MLA and replicated-KV cases don't split on heads. - # Mamba doesn't support heterogeneous TP. 
+ # The attention component of hybrid models still requires HND for + # head-dimension splitting under heterogeneous TP. if ( abs(tp_ratio) != 1 and not self.use_mla @@ -2299,10 +2251,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): req_id, ) # Get side handles. - # For hybrid MLA+GDN with tp_ratio < 0, SSM needs split handles to - # assemble data from multiple remote ranks. MLA attention reads - # use the full region (replicated, single rank) but the split - # handle applies offset 0 + full chunk for FA when fa_num_splits=1. + # Hybrid MLA+GDN: SSM needs split handles for multi-rank assembly. if tp_ratio < 0 and (not self.use_mla or self._has_mamba): # Remote tp_size > local tp_size: we must perform multiple # reads. Get the memory chunk onto which we will write to. @@ -2463,6 +2412,7 @@ def _read_blocks( ) assert len(local_block_descs_ids) == len(remote_block_descs_ids) + # Prepare transfer with Nixl. handle = None try: From bf6044f1adf26bc563e759d47c8dbc5f1ca5a6d8 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 9 Jun 2026 23:36:19 +0800 Subject: [PATCH 35/36] more unit tests Signed-off-by: JaredforReal --- .../unit/test_nixl_connector_hma.py | 205 +++++++++++++++++- 1 file changed, 203 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index c09d982a3afc..57f5f37d48a5 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -1111,6 +1111,7 @@ def _make_mock_worker_for_desc(**overrides): "_is_attn_region": [], "_attn_block_len": {}, "_has_mamba": True, + "use_mla": False, "num_regions": 0, "_group_spec_types": (), "_conv_decomp": None, @@ -1140,12 +1141,13 @@ def _make_mock_nixl_meta( class TestBuildFaLocalDualPurpose: - """Tests for _build_fa_local with dual-purpose HMA regions.""" + """Tests for _build_fa_local with dual-purpose HMA regions (MLA models).""" 
@pytest.mark.cpu_test def test_dual_purpose_uses_mla_stride(self): """Dual-purpose regions use _attn_block_len (MLA) stride, not KDA.""" worker = _make_mock_worker_for_desc( + use_mla=True, block_len_per_layer=[200], # KDA stride _is_ssm_region=[True], _is_attn_region=[True], @@ -1163,6 +1165,7 @@ def test_dual_purpose_uses_mla_stride(self): def test_skips_ssm_only_regions(self): """Pure SSM regions (not attn) are skipped entirely.""" worker = _make_mock_worker_for_desc( + use_mla=True, block_len_per_layer=[200, 200], _is_ssm_region=[True, True], _is_attn_region=[True, False], # region 1 is SSM-only @@ -1179,6 +1182,7 @@ def test_skips_ssm_only_regions(self): def test_no_block_size_ratio_for_dual_purpose(self): """Dual-purpose: block_size_ratio does NOT scale MLA stride.""" worker = _make_mock_worker_for_desc( + use_mla=True, block_len_per_layer=[200], _is_ssm_region=[True], _is_attn_region=[True], @@ -1195,6 +1199,7 @@ def test_no_block_size_ratio_for_dual_purpose(self): def test_mixed_regions(self): """Mix of dual-purpose, pure SSM, and pure attention regions.""" worker = _make_mock_worker_for_desc( + use_mla=True, block_len_per_layer=[200, 128, 300], _is_ssm_region=[True, True, False], _is_attn_region=[True, False, True], @@ -1212,7 +1217,7 @@ def test_mixed_regions(self): class TestBuildFaRemoteDualPurpose: - """Tests for _build_fa_remote with dual-purpose HMA regions.""" + """Tests for _build_fa_remote with dual-purpose HMA regions (MLA models).""" @pytest.mark.cpu_test def test_dual_purpose_uses_local_mla_stride(self): @@ -1220,6 +1225,7 @@ def test_dual_purpose_uses_local_mla_stride(self): from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec worker = _make_mock_worker_for_desc( + use_mla=True, block_len_per_layer=[200], _is_ssm_region=[True], _is_attn_region=[True], @@ -1246,6 +1252,7 @@ def test_no_ratio_scaling_for_dual_purpose(self): from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec worker = _make_mock_worker_for_desc( + 
use_mla=True, block_len_per_layer=[200], _is_ssm_region=[True], _is_attn_region=[True], @@ -1326,3 +1333,197 @@ def test_pure_mla_no_dual_purpose(self): assert len(result) == 3 for i, (addr, size, dev) in enumerate(result): assert addr == 0x1000 + i * 512 + + +# ── Qwen heterogeneous TP regression tests ─────────────────────────────── + + +class TestQwenHeteroTPRegression: + """Verify Qwen (standard GQA, use_mla=False) heterogeneous TP is not + broken by the _attn_block_len mechanism introduced for KimiLinear. + + The root cause of the regression: _attn_block_len was populated for + Qwen HMA dual-purpose regions, causing _build_fa_local/_build_fa_remote + to use the MLA stride path which skips block_size_ratio scaling. + For standard attention, stride IS TP-dependent, so skipping the ratio + produces wrong descriptor sizes and 0 successful KV transfers. + + The fix: guard attn_stride with `self.use_mla`. + """ + + @pytest.mark.cpu_test + def test_qwen_hma_dual_purpose_ignores_attn_stride_local(self): + """Qwen HMA with dual-purpose regions: _build_fa_local must NOT + use _attn_block_len stride when use_mla=False. + + Without the use_mla guard, attn_stride=1024 would be used directly + (no block_size_ratio scaling), producing wrong descriptor sizes. + With the guard, the standard path applies block_size_ratio correctly. 
+ """ + worker = _make_mock_worker_for_desc( + use_mla=False, + block_len_per_layer=[512], # SSM stride (local, TP=2) + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 1024}, # attention stride (local, TP=2) + num_blocks=2, + ) + # block_size_ratio=2: local blocks are 2× remote blocks + result = worker._build_fa_local([0x1000], block_size_ratio=2) + + # Standard path: block_len_per_layer[0] // block_size_ratio = 512//2 = 256 + # num_blocks = 2 * 2 = 4 descriptors with stride 256 + assert len(result) == 4 + for i, (addr, size, dev) in enumerate(result): + assert size == 256 # 512 // 2, NOT 1024 (unscaled attn stride) + assert addr == 0x1000 + i * 256 + + @pytest.mark.cpu_test + def test_qwen_hma_dual_purpose_ignores_attn_stride_remote(self): + """Qwen HMA with dual-purpose regions: _build_fa_remote must NOT + use _attn_block_len stride when use_mla=False.""" + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + worker = _make_mock_worker_for_desc( + use_mla=False, + block_len_per_layer=[512], # local SSM stride + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 1024}, # local attention stride + _group_spec_types=(FullAttentionSpec, MambaSpec), + ) + plan = MagicMock() + plan.source_ranks_per_group = (MagicMock(),) + plan.source_ranks_per_group[0].__len__ = MagicMock(return_value=1) + plan.rank_offset_factor = 0 + # Remote block_lens = P-side SSM stride (TP=4, smaller) + meta = _make_mock_nixl_meta(base_addrs=[0x5000], block_lens=[256], num_blocks=4) + + result = worker._build_fa_remote(plan, meta, block_size_ratio=1) + + # Standard path with block_size_ratio=1: + # get_backend_aware_kv_block_len → block_len_per_layer[0] = 512 + # remote_kv_block_len = 512 // 1 = 512 + # local_block_len = 512 // 1 (num_attn_reads) = 512 + # page_size = remote's block_lens[0] = 256 + assert len(result) == 4 + for i, (addr, size, dev) in enumerate(result): + assert addr == 0x5000 + i * 256 # steps by remote page_size + 
assert size == 512 + + @pytest.mark.cpu_test + def test_qwen_hetero_tp_local_applies_block_size_ratio(self): + """Qwen P4D2 (D-side, TP=2): local descriptors must scale by ratio. + + Without the fix, attn_stride (D's attention stride) was used without + //block_size_ratio, producing descriptors 2× the correct size. + """ + worker = _make_mock_worker_for_desc( + use_mla=False, + block_len_per_layer=[1024], # D-side SSM stride (large, TP=2) + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 2048}, # D-side attention stride (2×KV heads) + num_blocks=2, + ) + + result = worker._build_fa_local([0x1000], block_size_ratio=2) + + # num_blocks = 2 * 2 = 4 + # Standard path: stride = 1024 // 2 = 512, size = 1024 // 2 = 512 + # Without fix: stride = 2048 (wrong!), size = 2048 (wrong!) + assert len(result) == 4 + for i, (addr, size, dev) in enumerate(result): + assert size == 512 + assert addr == 0x1000 + i * 512 + + @pytest.mark.cpu_test + def test_qwen_homo_tp_unaffected(self): + """Qwen homogeneous TP: block_size_ratio=1, both paths produce + same result. 
This verifies no regression for the working case.""" + worker = _make_mock_worker_for_desc( + use_mla=False, + block_len_per_layer=[512], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 1024}, + num_blocks=3, + ) + + result = worker._build_fa_local([0x1000], block_size_ratio=1) + + # Standard path with ratio=1: stride=512, size=512 + assert len(result) == 3 + for i, (addr, size, dev) in enumerate(result): + assert size == 512 + assert addr == 0x1000 + i * 512 + + +class TestKimiLinearDualPurpose: + """Verify KimiLinear (use_mla=True, KDA+MLA) dual-purpose regions + correctly use the MLA stride path with the use_mla guard.""" + + @pytest.mark.cpu_test + def test_mla_stride_path_activates_for_kimilinear(self): + """KimiLinear: use_mla=True → attn_stride path is taken.""" + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200], # KDA stride + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, # MLA stride (TP-independent) + num_blocks=3, + ) + + result = worker._build_fa_local([0x1000], block_size_ratio=1) + + # MLA path: stride=256 (unscaled), size=256 + assert len(result) == 3 + for i, (addr, size, dev) in enumerate(result): + assert size == 256 # MLA stride, not KDA's 200 + assert addr == 0x1000 + i * 256 + + @pytest.mark.cpu_test + def test_mla_stride_unscaled_by_block_size_ratio(self): + """KimiLinear: MLA stride is TP-independent, block_size_ratio + must NOT be applied (MLA num_kv_heads=1 regardless of TP).""" + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, + num_blocks=2, + ) + + result = worker._build_fa_local([0x1000], block_size_ratio=2) + + # MLA path: stride=256 (NOT 256//2=128) + assert result[0][1] == 256 + assert result[1][0] == 0x1000 + 256 + + @pytest.mark.cpu_test + def test_mla_remote_uses_attn_stride_as_page_size(self): + """KimiLinear remote: page_size = 
attn_stride (MLA, TP-independent).""" + from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + worker = _make_mock_worker_for_desc( + use_mla=True, + block_len_per_layer=[200], + _is_ssm_region=[True], + _is_attn_region=[True], + _attn_block_len={0: 256}, + _group_spec_types=(FullAttentionSpec, MambaSpec), + ) + plan = MagicMock() + plan.source_ranks_per_group = (MagicMock(),) + plan.source_ranks_per_group[0].__len__ = MagicMock(return_value=1) + plan.rank_offset_factor = 0 + meta = _make_mock_nixl_meta(base_addrs=[0x5000], block_lens=[200], num_blocks=3) + + result = worker._build_fa_remote(plan, meta, block_size_ratio=1) + + # MLA path: page_size = attn_stride = 256 (not remote's 200) + assert result[0] == (0x5000, 256, 0) + assert result[1] == (0x5000 + 256, 256, 0) + assert result[2] == (0x5000 + 512, 256, 0) From 306f9bc7abc51ac80fdca2cdff82f221ada77005 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Thu, 11 Jun 2026 16:48:55 +0800 Subject: [PATCH 36/36] handle blocksize ratio in TransferTopology Signed-off-by: JaredforReal --- .../kv_transfer/kv_connector/utils.py | 15 ++++++--- .../kv_connector/v1/nixl/worker.py | 31 +++++++------------ 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 0ab694b7e73d..b4048af61ca9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -533,11 +533,16 @@ def tp_ratio(self, remote_tp_size: int) -> int: return -(remote_tp_size // self.tp_size) def block_size_ratio(self, remote_block_size: int) -> int: - """Calculate the block size ratio between local and remote.""" - assert self.block_size % remote_block_size == 0, ( - f"Local block size {self.block_size} is not divisible " - f"by remote block size {remote_block_size} or vice versa." - ) + """Calculate the block size ratio between local and remote. 
+ + When the local and remote block sizes are not evenly divisible + (e.g. hybrid MLA+GDN models whose MLA component is TP-independent), + returns ``1`` as a safe fallback rather than raising. Downstream + code in the nixl worker handles the non-divisible case via + byte-level ``block_lens``. + """ + if self.block_size % remote_block_size != 0: + return 1 return self.block_size // remote_block_size def is_kv_replicated( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index bef5879bee7e..46a797d393a7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1556,14 +1556,11 @@ def _validate_remote_agent_handshake( assert remote_info.remote_tp_size == remote_tp_size tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size) - try: - block_size_ratio = self.transfer_topo.block_size_ratio( - nixl_agent_meta.block_size - ) - except AssertionError: - # Heterogeneous TP with non-divisible block sizes (e.g. hybrid - # MLA+GDN). Use 1 as a safe fallback for validation checks. - block_size_ratio = 1 + # Heterogeneous TP with non-divisible block sizes (e.g. hybrid + # MLA+GDN) falls back to 1 inside block_size_ratio. + block_size_ratio = self.transfer_topo.block_size_ratio( + nixl_agent_meta.block_size + ) # num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba. # Mamba models can have replicated FA KV with tp_ratio < 0. # MLA models do not need to handle kv replication. 
@@ -1928,12 +1925,9 @@ def get_finished(self) -> tuple[set[str], set[str]]: # post processing for heteroblocksize remote_info = self.transfer_topo.get_engine_info(meta.remote.engine_id) - try: - block_size_ratio = self.transfer_topo.block_size_ratio( - remote_info.remote_block_size - ) - except AssertionError: - block_size_ratio = 1 + block_size_ratio = self.transfer_topo.block_size_ratio( + remote_info.remote_block_size + ) if not self.use_mla and ( block_size_ratio > 1 or self.enable_permute_local_kv ): @@ -2325,12 +2319,9 @@ def _read_blocks( remote_block_ids = read_spec.remote_block_ids remote_info = self.transfer_topo.get_engine_info(dst_engine_id) - try: - block_size_ratio = self.transfer_topo.block_size_ratio( - remote_info.remote_block_size - ) - except AssertionError: - block_size_ratio = 1 + block_size_ratio = self.transfer_topo.block_size_ratio( + remote_info.remote_block_size + ) if block_size_ratio > 1: # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups. assert not self._is_hma_required