From 6e4914008affd368878664e7f507bedf1d3f93eb Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Tue, 17 Mar 2026 15:51:43 +0000 Subject: [PATCH 1/8] [SSM/Mamba] N-1 prefill for P/D disaggregation For HMA (Mamba/SSM) models in P/D disaggregation, the prefiller must transfer h(N-1) instead of h(N) so the decoder can correctly recompute the last prompt token. D-side: _hma_prefill_token_count() helper returns N-1 for HMA models, used in get_num_new_matched_tokens so the decoder naturally recomputes the last token from h(N-1). P-side: _truncate_hma_request_for_prefill() truncates prompt to N-1 tokens and sets max_tokens=1. The model computes h(N-1), samples one spurious token (which does NOT update Mamba state), then check_stop fires FINISHED_LENGTH_CAPPED triggering the KV transfer. The P-side truncation is guarded by params["_p_side_truncated"] for idempotency across preemption / re-scheduling cycles. Signed-off-by: ZhanqiuHu --- .../kv_connector/v1/nixl_connector.py | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 79a04bcb95e0..fa873d9c5fb4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -717,6 +717,40 @@ def _nixl_handshake_listener( logger.warning("Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) + def _hma_prefill_token_count(self, num_prompt_tokens: int) -> int: + """For HMA/Mamba models, returns N-1 to avoid transferring + state that already includes the last prompt token's contribution. + The decoder will recompute the last token to derive h(N).""" + if self._is_hma_required and num_prompt_tokens > 1: + return num_prompt_tokens - 1 + return num_prompt_tokens + + def _truncate_hma_request_for_prefill( + self, request: "Request" + ) -> None: + """Truncate P-side prompt to N-1 tokens for HMA/Mamba models. + + On the prefill worker (do_remote_decode), we remove the last + prompt token so the model computes h(N-1). Setting max_tokens=1 + causes the request to finish after one decode sample (which does + NOT update Mamba state), triggering the KV transfer of h(N-1). + + Guarded by params["_p_side_truncated"] for idempotency across + preemption / re-scheduling cycles. + """ + params = request.kv_transfer_params + if ( + params is not None + and params.get("do_remote_decode") + and not params.get("_p_side_truncated") + ): + if request.prompt_token_ids and len(request.prompt_token_ids) > 1: + request.prompt_token_ids.pop() + request._all_token_ids.pop() + request.num_prompt_tokens -= 1 + request.max_tokens = 1 + params["_p_side_truncated"] = True + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int, bool]: @@ -746,10 +780,14 @@ def get_num_new_matched_tokens( if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. token_ids = request.prompt_token_ids or [] - count = len(token_ids) - num_computed_tokens + actual = self._hma_prefill_token_count(len(token_ids)) + count = actual - num_computed_tokens if count > 0: return count, True + if self._is_hma_required: + self._truncate_hma_request_for_prefill(request) + # No remote prefill for this request. return 0, False From 172e7e684017ab9c071ac0a2900392aebc5f4beb Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Tue, 17 Mar 2026 16:05:32 +0000 Subject: [PATCH 2/8] Refactor: cleaner call-site guards and docstrings - Extract P-side truncation into _truncate_hma_request_for_prefill() - Add _hma_prefill_token_count() helper for D-side N-1 calculation - Explicit do_remote_decode / do_remote_prefill guards at call site - Tighten docstrings with D-side/P-side context Signed-off-by: ZhanqiuHu --- .../kv_connector/v1/nixl_connector.py | 29 +++++++------------ 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index fa873d9c5fb4..26f491e4d310 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -718,9 +718,8 @@ def _nixl_handshake_listener( sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) def _hma_prefill_token_count(self, num_prompt_tokens: int) -> int: - """For HMA/Mamba models, returns N-1 to avoid transferring - state that already includes the last prompt token's contribution. - The decoder will recompute the last token to derive h(N).""" + """D-side only. Returns N-1 for HMA models since the decoder + always recomputes the last token and must start from h(N-1).""" if self._is_hma_required and num_prompt_tokens > 1: return num_prompt_tokens - 1 return num_prompt_tokens @@ -728,22 +727,13 @@ def _hma_prefill_token_count(self, num_prompt_tokens: int) -> int: def _truncate_hma_request_for_prefill( self, request: "Request" ) -> None: - """Truncate P-side prompt to N-1 tokens for HMA/Mamba models. + """For P-side HMA requests, drop the last prompt token so the + prefiller computes h(N-1) instead of h(N). The decoder will + recompute the last token to derive h(N) correctly. - On the prefill worker (do_remote_decode), we remove the last - prompt token so the model computes h(N-1). Setting max_tokens=1 - causes the request to finish after one decode sample (which does - NOT update Mamba state), triggering the KV transfer of h(N-1). - - Guarded by params["_p_side_truncated"] for idempotency across - preemption / re-scheduling cycles. - """ + Idempotent: skips if already truncated.""" params = request.kv_transfer_params - if ( - params is not None - and params.get("do_remote_decode") - and not params.get("_p_side_truncated") - ): + if params is not None and not params.get("_p_side_truncated"): if request.prompt_token_ids and len(request.prompt_token_ids) > 1: request.prompt_token_ids.pop() request._all_token_ids.pop() @@ -785,8 +775,9 @@ def get_num_new_matched_tokens( if count > 0: return count, True - if self._is_hma_required: - self._truncate_hma_request_for_prefill(request) + if params is not None and params.get("do_remote_decode"): + if self._is_hma_required: + self._truncate_hma_request_for_prefill(request) # No remote prefill for this request. return 0, False From f1c4cb69ded02ce186aefa62a629f75c67f92f16 Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Tue, 17 Mar 2026 18:00:05 +0000 Subject: [PATCH 3/8] Fix ruff SIM102: collapse nested if statements Signed-off-by: ZhanqiuHu --- .../kv_connector/v1/nixl_connector.py | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 26f491e4d310..e76198e7b3f4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -724,22 +724,24 @@ def _hma_prefill_token_count(self, num_prompt_tokens: int) -> int: return num_prompt_tokens - 1 return num_prompt_tokens - def _truncate_hma_request_for_prefill( - self, request: "Request" - ) -> None: + def _truncate_hma_request_for_prefill(self, request: "Request") -> None: """For P-side HMA requests, drop the last prompt token so the prefiller computes h(N-1) instead of h(N). The decoder will recompute the last token to derive h(N) correctly. Idempotent: skips if already truncated.""" params = request.kv_transfer_params - if params is not None and not params.get("_p_side_truncated"): - if request.prompt_token_ids and len(request.prompt_token_ids) > 1: - request.prompt_token_ids.pop() - request._all_token_ids.pop() - request.num_prompt_tokens -= 1 - request.max_tokens = 1 - params["_p_side_truncated"] = True + if ( + params is not None + and not params.get("_p_side_truncated") + and request.prompt_token_ids + and len(request.prompt_token_ids) > 1 + ): + request.prompt_token_ids.pop() + request._all_token_ids.pop() + request.num_prompt_tokens -= 1 + request.max_tokens = 1 + params["_p_side_truncated"] = True def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int @@ -775,9 +777,12 @@ def get_num_new_matched_tokens( if count > 0: return count, True - if params is not None and params.get("do_remote_decode"): - if self._is_hma_required: - self._truncate_hma_request_for_prefill(request) + if ( + params is not None + and params.get("do_remote_decode") + and self._is_hma_required + ): + self._truncate_hma_request_for_prefill(request) # No remote prefill for this request. return 0, False From 2130daf1f1cce256ddc0393e786f7c4e526d1ebc Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Tue, 17 Mar 2026 18:25:24 +0000 Subject: [PATCH 4/8] Use _has_mamba instead of _is_hma_required for N-1 logic _is_hma_required is True for any non-FullAttention model (including SWA), but the N-1 prefill fix only applies to models with cumulative Mamba state. SWA KV is stateless and doesn't need N-1 treatment. Signed-off-by: ZhanqiuHu --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e76198e7b3f4..f274671cff8d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -572,6 +572,10 @@ def __init__( for g in kv_cache_config.kv_cache_groups ) ) + self._has_mamba = any( + isinstance(g.kv_cache_spec, MambaSpec) + for g in kv_cache_config.kv_cache_groups + ) logger.info("Initializing NIXL Scheduler %s", engine_id) if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: @@ -718,9 +722,9 @@ def _nixl_handshake_listener( sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) def _hma_prefill_token_count(self, num_prompt_tokens: int) -> int: - """D-side only. Returns N-1 for HMA models since the decoder + """D-side only. Returns N-1 for Mamba models since the decoder always recomputes the last token and must start from h(N-1).""" - if self._is_hma_required and num_prompt_tokens > 1: + if self._has_mamba and num_prompt_tokens > 1: return num_prompt_tokens - 1 return num_prompt_tokens @@ -777,11 +781,7 @@ def get_num_new_matched_tokens( if count > 0: return count, True - if ( - params is not None - and params.get("do_remote_decode") - and self._is_hma_required - ): + if params is not None and params.get("do_remote_decode") and self._has_mamba: self._truncate_hma_request_for_prefill(request) # No remote prefill for this request. From 3a18e8e64eee708830c8470a92562fe44b8e95e8 Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Tue, 17 Mar 2026 19:29:47 +0000 Subject: [PATCH 5/8] Rename _hma_ helpers to _mamba_ for clarity Signed-off-by: ZhanqiuHu --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f274671cff8d..dd9bd00de3fb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -721,14 +721,14 @@ def _nixl_handshake_listener( logger.warning("Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) - def _hma_prefill_token_count(self, num_prompt_tokens: int) -> int: + def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int: """D-side only. Returns N-1 for Mamba models since the decoder always recomputes the last token and must start from h(N-1).""" if self._has_mamba and num_prompt_tokens > 1: return num_prompt_tokens - 1 return num_prompt_tokens - def _truncate_hma_request_for_prefill(self, request: "Request") -> None: + def _truncate_mamba_request_for_prefill(self, request: "Request") -> None: """For P-side HMA requests, drop the last prompt token so the prefiller computes h(N-1) instead of h(N). The decoder will recompute the last token to derive h(N) correctly. @@ -776,13 +776,13 @@ def get_num_new_matched_tokens( if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. token_ids = request.prompt_token_ids or [] - actual = self._hma_prefill_token_count(len(token_ids)) + actual = self._mamba_prefill_token_count(len(token_ids)) count = actual - num_computed_tokens if count > 0: return count, True if params is not None and params.get("do_remote_decode") and self._has_mamba: - self._truncate_hma_request_for_prefill(request) + self._truncate_mamba_request_for_prefill(request) # No remote prefill for this request. return 0, False From f29f1b6c06cff3265ce7ce8f975773d3e23e0542 Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Tue, 17 Mar 2026 19:38:16 +0000 Subject: [PATCH 6/8] Add comment explaining _p_side_truncated preemption guard Signed-off-by: ZhanqiuHu --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index dd9bd00de3fb..3160c42ca990 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -729,14 +729,16 @@ def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int: return num_prompt_tokens def _truncate_mamba_request_for_prefill(self, request: "Request") -> None: - """For P-side HMA requests, drop the last prompt token so the - prefiller computes h(N-1) instead of h(N). The decoder will - recompute the last token to derive h(N) correctly. + """P-side only: drop the last prompt token so the prefiller computes + h(N-1) instead of h(N). The decoder recomputes the last token to + derive h(N) correctly. - Idempotent: skips if already truncated.""" + Guarded by ``_p_side_truncated`` to avoid repeated truncation if the + request is preempted and rescheduled.""" params = request.kv_transfer_params if ( params is not None + # Guard against repeated truncation after preemption/reschedule. and not params.get("_p_side_truncated") and request.prompt_token_ids and len(request.prompt_token_ids) > 1 From 458ca60fc3741e5e58c155f70a1f36a659220c84 Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Wed, 18 Mar 2026 14:26:59 +0000 Subject: [PATCH 7/8] add test cases Signed-off-by: ZhanqiuHu --- .../kv_connector/unit/test_nixl_connector.py | 2 +- .../unit/test_nixl_connector_hma.py | 121 ++++++++++++++++-- .../unit/test_remote_prefill_lifecycle.py | 78 ++++++++++- tests/v1/kv_connector/unit/utils.py | 32 ++++- 4 files changed, 221 insertions(+), 12 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 3da1b533ad73..bda9e43c7829 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -2012,7 +2012,7 @@ def test_transfer_failure_logging( connector = NixlConnector( vllm_config, KVConnectorRole.WORKER, - make_kv_cache_config(block_size=16, hma_enabled=enable_hma), + make_kv_cache_config(block_size=16, swa_enabled=enable_hma), ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index d4b0c28a5de5..898f8e4b35ba 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA.""" +"""Unit tests for NixlConnectorScheduler with HMA and Mamba N-1 prefill.""" from unittest.mock import patch @@ -14,24 +14,26 @@ ) from .utils import ( + create_request, create_vllm_config, make_kv_cache_config, + make_nixl_scheduler, ) @pytest.mark.cpu_test @pytest.mark.parametrize( - "hma_enabled,expected_sw_sizes", + "swa_enabled,expected_sw_sizes", [ - # HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128) + # SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128) (True, [0, 128 + 1]), - # HMA disabled: only FullAttentionSpec (0) + # SWA disabled: only FullAttentionSpec (0) (False, [0]), ], ) @patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") -def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes): - """Test sw_sizes is correctly computed based on HMA enabled/disabled.""" +def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes): + """Test sw_sizes is correctly computed based on SWA enabled/disabled.""" from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlConnectorScheduler, ) @@ -42,7 +44,7 @@ def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes): vllm_config = create_vllm_config(block_size=block_size) # SW 2048 tokens=>128 blocks kv_cache_config = make_kv_cache_config( - block_size=block_size, hma_enabled=hma_enabled, sw_size=2048 + block_size=block_size, swa_enabled=swa_enabled, sw_size=2048 ) scheduler = NixlConnectorScheduler( @@ -75,7 +77,7 @@ def test_logical_to_kernel_block_ids_with_hma(): # So each logical block maps to 2 kernel blocks eg [0]->[0,1] worker._physical_blocks_per_logical_kv_block = 2 # FA + SW groups (neither is MambaSpec, so both get expanded) - worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True) + worker.kv_cache_config = make_kv_cache_config(block_size=16, swa_enabled=True) # Test conversion: FA + SW group logical_block_ids = [[0, 1, 2], [3, 4]] @@ -313,3 +315,106 @@ def test_nixl_metadata_hybrid_ssm_block_ids(): assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17] assert list(req_meta.remote.block_ids[1]) == [20, 21] assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1]) + + +# ── Mamba N-1 prefill tests ────────────────────────────────────────────── + + +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "has_mamba,is_hma_required,expected_count", + [ + (True, True, 9), + (False, False, 10), + (False, True, 10), + ], + ids=["mamba", "fa_only", "swa_only"], +) +def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count): + """D-side: Mamba gets N-1 matched tokens, non-Mamba gets N.""" + sched = make_nixl_scheduler(has_mamba=has_mamba, is_hma_required=is_hma_required) + req = create_request(num_tokens=10, do_remote_prefill=True) + + count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0) + assert count == expected_count + assert is_async is True + + +@pytest.mark.cpu_test +def test_mamba_n1_p_side_truncation(): + """P-side: Mamba truncates prompt to N-1, sets max_tokens=1. + + Also verifies idempotency (calling again is a no-op) which is + needed for preemption safety via the _p_side_truncated guard, + and that non-Mamba models skip truncation entirely. + """ + sched = make_nixl_scheduler(has_mamba=True, is_hma_required=True) + req = create_request(num_tokens=10, do_remote_decode=True) + req.max_tokens = 128 + original_len = len(req.prompt_token_ids) + + count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0) + + assert count == 0 + assert is_async is False + assert len(req.prompt_token_ids) == original_len - 1 + assert req.num_prompt_tokens == original_len - 1 + assert req.max_tokens == 1 + assert req.kv_transfer_params["_p_side_truncated"] is True + + # Idempotency: second call must not truncate further + sched.get_num_new_matched_tokens(req, num_computed_tokens=0) + assert len(req.prompt_token_ids) == original_len - 1 + + # Non-Mamba: truncation is skipped + fa_sched = make_nixl_scheduler(has_mamba=False, is_hma_required=False) + fa_req = create_request(num_tokens=10, do_remote_decode=True) + fa_original = len(fa_req.prompt_token_ids) + + fa_sched.get_num_new_matched_tokens(fa_req, num_computed_tokens=0) + assert len(fa_req.prompt_token_ids) == fa_original + + +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "swa_enabled,mamba_enabled,expected_has_mamba,expected_is_hma", + [ + (True, True, True, True), + (True, False, False, True), + (False, False, False, False), + ], + ids=["fa_swa_mamba", "fa_swa_only", "fa_only"], +) +@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") +def test_has_mamba_init( + mock_platform, + swa_enabled, + mamba_enabled, + expected_has_mamba, + expected_is_hma, +): + """Test _has_mamba / _is_hma_required derived from kv_cache_groups.""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorScheduler, + ) + + mock_platform.device_type = "cpu" + + block_size = 16 + vllm_config = create_vllm_config(block_size=block_size) + # VllmConfig.__post_init__ auto-disables HMA when kv_transfer_config + # is set; override so we can test the scheduler's own derivation. + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + kv_cache_config = make_kv_cache_config( + block_size=block_size, + swa_enabled=swa_enabled, + mamba_enabled=mamba_enabled, + ) + + scheduler = NixlConnectorScheduler( + vllm_config=vllm_config, + engine_id="test-engine", + kv_cache_config=kv_cache_config, + ) + assert scheduler._has_mamba is expected_has_mamba + assert scheduler._is_hma_required is expected_is_hma diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index f48dc0fff602..283b4f25e6e4 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +from unittest.mock import patch import pytest -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + KVConnectorOutput, + ModelRunnerOutput, +) from vllm.v1.request import FinishReason, RequestStatus from .utils import ( @@ -13,6 +18,7 @@ create_request, create_scheduler, create_vllm_config, + make_kv_cache_config, ) pytestmark = pytest.mark.cpu_test @@ -579,3 +585,73 @@ def test_cannot_recv(): scheduler.update_from_output(scheduler_output, model_runner_output) _ = scheduler.schedule() assert_scheduler_empty(scheduler) + + +@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") +def test_p_side_chunked_prefill_mamba(mock_platform): + """P-side integration: Mamba N-1 truncation + chunked prefill completes. + + A 64-token P-side request is truncated to 63 by the N-1 fix, then + chunked into two prefill steps (32 + 31) and finishes with + LENGTH_CAPPED because max_tokens is set to 1. + """ + mock_platform.device_type = "cpu" + + BATCH_SIZE = 32 + NUM_TOKENS = 64 + BLOCK_SIZE = 16 + + vllm_config = create_vllm_config( + max_num_batched_tokens=BATCH_SIZE, + block_size=BLOCK_SIZE, + ) + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + + kv_cache_config = make_kv_cache_config( + block_size=BLOCK_SIZE, + mamba_enabled=True, + num_blocks=10000, + ) + + scheduler = create_scheduler(vllm_config, kv_cache_config=kv_cache_config) + + request = create_request( + num_tokens=NUM_TOKENS, + do_remote_decode=True, + block_size=BLOCK_SIZE, + ) + request.max_tokens = 128 + scheduler.add_request(request) + request_id = request.request_id + + # ── Step 1: first chunk ── + scheduler_output = scheduler.schedule() + + assert len(request.prompt_token_ids) == NUM_TOKENS - 1 + assert request.max_tokens == 1 + assert scheduler_output.num_scheduled_tokens[request_id] == BATCH_SIZE + assert request.num_computed_tokens == BATCH_SIZE + + # Model returns no tokens for intermediate prefill chunk + intermediate_output = ModelRunnerOutput( + req_ids=[request.request_id], + req_id_to_index={request.request_id: 0}, + sampled_token_ids=[[]], + ) + scheduler.update_from_output(scheduler_output, intermediate_output) + + # ── Step 2: remaining chunk ── + scheduler_output = scheduler.schedule() + + remaining = NUM_TOKENS - 1 - BATCH_SIZE # 31 + assert scheduler_output.num_scheduled_tokens[request_id] == remaining + assert request.num_computed_tokens == NUM_TOKENS - 1 + + # Prefill complete: model generates 1 decode token + final_output = create_model_runner_output([request]) + engine_core_outputs = scheduler.update_from_output(scheduler_output, final_output) + + # max_tokens=1 → request finishes with LENGTH + outputs = engine_core_outputs[0].outputs + assert len(outputs) == 1 + assert outputs[0].finish_reason == FinishReason.LENGTH diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 6e00cf8d5bed..1e2a05f0e345 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -37,6 +37,7 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + MambaSpec, SlidingWindowSpec, ) from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -423,7 +424,8 @@ def wait_for_save(self): def make_kv_cache_config( block_size: int, - hma_enabled: bool = False, + swa_enabled: bool = False, + mamba_enabled: bool = False, sw_size: int = 128, num_blocks: int = 100, ) -> KVCacheConfig: @@ -438,7 +440,7 @@ def make_kv_cache_config( ), ) ] - if hma_enabled: + if swa_enabled: kv_cache_groups.append( KVCacheGroupSpec( ["layer1", "layer3"], @@ -451,6 +453,32 @@ def make_kv_cache_config( ), ) ) + if mamba_enabled: + kv_cache_groups.append( + KVCacheGroupSpec( + ["mamba0", "mamba1"], + MambaSpec( + block_size=block_size, + shapes=((16,), (16,)), + dtypes=(torch.float16,), + ), + ) + ) return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups ) + + +def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False): + """Create a NixlConnectorScheduler via __new__ (skipping __init__). + + Only sets the two flags needed by the N-1 prefill logic. + """ + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorScheduler, + ) + + sched = object.__new__(NixlConnectorScheduler) + sched._has_mamba = has_mamba + sched._is_hma_required = is_hma_required + return sched From 17f996fbad27b924847cc7128126756d4d252124 Mon Sep 17 00:00:00 2001 From: ZhanqiuHu Date: Wed, 18 Mar 2026 17:45:13 +0000 Subject: [PATCH 8/8] handle prompt embeddings Signed-off-by: ZhanqiuHu --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 3160c42ca990..ed53c35c9ed9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -740,10 +740,15 @@ def _truncate_mamba_request_for_prefill(self, request: "Request") -> None: params is not None # Guard against repeated truncation after preemption/reschedule. and not params.get("_p_side_truncated") - and request.prompt_token_ids - and len(request.prompt_token_ids) > 1 + and request.num_prompt_tokens > 1 ): - request.prompt_token_ids.pop() + if request.prompt_token_ids is not None: + request.prompt_token_ids.pop() + elif request.prompt_embeds is not None: + request.prompt_embeds = request.prompt_embeds[:-1] + else: + return + request._all_token_ids.pop() request.num_prompt_tokens -= 1 request.max_tokens = 1