diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 3da1b533ad73..bda9e43c7829 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -2012,7 +2012,7 @@ def test_transfer_failure_logging( connector = NixlConnector( vllm_config, KVConnectorRole.WORKER, - make_kv_cache_config(block_size=16, hma_enabled=enable_hma), + make_kv_cache_config(block_size=16, swa_enabled=enable_hma), ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index d4b0c28a5de5..898f8e4b35ba 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA.""" +"""Unit tests for NixlConnectorScheduler with HMA and Mamba N-1 prefill.""" from unittest.mock import patch @@ -14,24 +14,26 @@ ) from .utils import ( + create_request, create_vllm_config, make_kv_cache_config, + make_nixl_scheduler, ) @pytest.mark.cpu_test @pytest.mark.parametrize( - "hma_enabled,expected_sw_sizes", + "swa_enabled,expected_sw_sizes", [ - # HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128) + # SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128) (True, [0, 128 + 1]), - # HMA disabled: only FullAttentionSpec (0) + # SWA disabled: only FullAttentionSpec (0) (False, [0]), ], ) @patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") -def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes): - """Test sw_sizes is correctly computed based on HMA enabled/disabled.""" +def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes): + """Test sw_sizes is correctly computed based on SWA enabled/disabled.""" from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlConnectorScheduler, ) @@ -42,7 +44,7 @@ def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes): vllm_config = create_vllm_config(block_size=block_size) # SW 2048 tokens=>128 blocks kv_cache_config = make_kv_cache_config( - block_size=block_size, hma_enabled=hma_enabled, sw_size=2048 + block_size=block_size, swa_enabled=swa_enabled, sw_size=2048 ) scheduler = NixlConnectorScheduler( @@ -75,7 +77,7 @@ def test_logical_to_kernel_block_ids_with_hma(): # So each logical block maps to 2 kernel blocks eg [0]->[0,1] worker._physical_blocks_per_logical_kv_block = 2 # FA + SW groups (neither is MambaSpec, so both get expanded) - worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True) + worker.kv_cache_config = make_kv_cache_config(block_size=16, swa_enabled=True) # Test conversion: FA + SW group logical_block_ids = [[0, 1, 2], [3, 4]] @@ -313,3 +315,106 @@ def test_nixl_metadata_hybrid_ssm_block_ids(): assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17] assert list(req_meta.remote.block_ids[1]) == [20, 21] assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1]) + + +# ── Mamba N-1 prefill tests ────────────────────────────────────────────── + + +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "has_mamba,is_hma_required,expected_count", + [ + (True, True, 9), + (False, False, 10), + (False, True, 10), + ], + ids=["mamba", "fa_only", "swa_only"], +) +def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count): + """D-side: Mamba gets N-1 matched tokens, non-Mamba gets N.""" + sched = make_nixl_scheduler(has_mamba=has_mamba, is_hma_required=is_hma_required) + req = create_request(num_tokens=10, do_remote_prefill=True) + + count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0) + assert count == expected_count + assert is_async is True + + +@pytest.mark.cpu_test +def test_mamba_n1_p_side_truncation(): + """P-side: Mamba truncates prompt to N-1, sets max_tokens=1. + + Also verifies idempotency (calling again is a no-op) which is + needed for preemption safety via the _p_side_truncated guard, + and that non-Mamba models skip truncation entirely. + """ + sched = make_nixl_scheduler(has_mamba=True, is_hma_required=True) + req = create_request(num_tokens=10, do_remote_decode=True) + req.max_tokens = 128 + original_len = len(req.prompt_token_ids) + + count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0) + + assert count == 0 + assert is_async is False + assert len(req.prompt_token_ids) == original_len - 1 + assert req.num_prompt_tokens == original_len - 1 + assert req.max_tokens == 1 + assert req.kv_transfer_params["_p_side_truncated"] is True + + # Idempotency: second call must not truncate further + sched.get_num_new_matched_tokens(req, num_computed_tokens=0) + assert len(req.prompt_token_ids) == original_len - 1 + + # Non-Mamba: truncation is skipped + fa_sched = make_nixl_scheduler(has_mamba=False, is_hma_required=False) + fa_req = create_request(num_tokens=10, do_remote_decode=True) + fa_original = len(fa_req.prompt_token_ids) + + fa_sched.get_num_new_matched_tokens(fa_req, num_computed_tokens=0) + assert len(fa_req.prompt_token_ids) == fa_original + + +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "swa_enabled,mamba_enabled,expected_has_mamba,expected_is_hma", + [ + (True, True, True, True), + (True, False, False, True), + (False, False, False, False), + ], + ids=["fa_swa_mamba", "fa_swa_only", "fa_only"], +) +@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") +def test_has_mamba_init( + mock_platform, + swa_enabled, + mamba_enabled, + expected_has_mamba, + expected_is_hma, +): + """Test _has_mamba / _is_hma_required derived from kv_cache_groups.""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorScheduler, + ) + + mock_platform.device_type = "cpu" + + block_size = 16 + vllm_config = create_vllm_config(block_size=block_size) + # VllmConfig.__post_init__ auto-disables HMA when kv_transfer_config + # is set; override so we can test the scheduler's own derivation. + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + kv_cache_config = make_kv_cache_config( + block_size=block_size, + swa_enabled=swa_enabled, + mamba_enabled=mamba_enabled, + ) + + scheduler = NixlConnectorScheduler( + vllm_config=vllm_config, + engine_id="test-engine", + kv_cache_config=kv_cache_config, + ) + assert scheduler._has_mamba is expected_has_mamba + assert scheduler._is_hma_required is expected_is_hma diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index f48dc0fff602..283b4f25e6e4 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +from unittest.mock import patch import pytest -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + KVConnectorOutput, + ModelRunnerOutput, +) from vllm.v1.request import FinishReason, RequestStatus from .utils import ( @@ -13,6 +18,7 @@ create_request, create_scheduler, create_vllm_config, + make_kv_cache_config, ) pytestmark = pytest.mark.cpu_test @@ -579,3 +585,73 @@ def test_cannot_recv(): scheduler.update_from_output(scheduler_output, model_runner_output) _ = scheduler.schedule() assert_scheduler_empty(scheduler) + + +@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") +def test_p_side_chunked_prefill_mamba(mock_platform): + """P-side integration: Mamba N-1 truncation + chunked prefill completes. + + A 64-token P-side request is truncated to 63 by the N-1 fix, then + chunked into two prefill steps (32 + 31) and finishes with + LENGTH_CAPPED because max_tokens is set to 1. + """ + mock_platform.device_type = "cpu" + + BATCH_SIZE = 32 + NUM_TOKENS = 64 + BLOCK_SIZE = 16 + + vllm_config = create_vllm_config( + max_num_batched_tokens=BATCH_SIZE, + block_size=BLOCK_SIZE, + ) + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + + kv_cache_config = make_kv_cache_config( + block_size=BLOCK_SIZE, + mamba_enabled=True, + num_blocks=10000, + ) + + scheduler = create_scheduler(vllm_config, kv_cache_config=kv_cache_config) + + request = create_request( + num_tokens=NUM_TOKENS, + do_remote_decode=True, + block_size=BLOCK_SIZE, + ) + request.max_tokens = 128 + scheduler.add_request(request) + request_id = request.request_id + + # ── Step 1: first chunk ── + scheduler_output = scheduler.schedule() + + assert len(request.prompt_token_ids) == NUM_TOKENS - 1 + assert request.max_tokens == 1 + assert scheduler_output.num_scheduled_tokens[request_id] == BATCH_SIZE + assert request.num_computed_tokens == BATCH_SIZE + + # Model returns no tokens for intermediate prefill chunk + intermediate_output = ModelRunnerOutput( + req_ids=[request.request_id], + req_id_to_index={request.request_id: 0}, + sampled_token_ids=[[]], + ) + scheduler.update_from_output(scheduler_output, intermediate_output) + + # ── Step 2: remaining chunk ── + scheduler_output = scheduler.schedule() + + remaining = NUM_TOKENS - 1 - BATCH_SIZE # 31 + assert scheduler_output.num_scheduled_tokens[request_id] == remaining + assert request.num_computed_tokens == NUM_TOKENS - 1 + + # Prefill complete: model generates 1 decode token + final_output = create_model_runner_output([request]) + engine_core_outputs = scheduler.update_from_output(scheduler_output, final_output) + + # max_tokens=1 → request finishes with LENGTH + outputs = engine_core_outputs[0].outputs + assert len(outputs) == 1 + assert outputs[0].finish_reason == FinishReason.LENGTH diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 6e00cf8d5bed..1e2a05f0e345 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -37,6 +37,7 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + MambaSpec, SlidingWindowSpec, ) from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -423,7 +424,8 @@ def wait_for_save(self): def make_kv_cache_config( block_size: int, - hma_enabled: bool = False, + swa_enabled: bool = False, + mamba_enabled: bool = False, sw_size: int = 128, num_blocks: int = 100, ) -> KVCacheConfig: @@ -438,7 +440,7 @@ def make_kv_cache_config( ), ) ] - if hma_enabled: + if swa_enabled: kv_cache_groups.append( KVCacheGroupSpec( ["layer1", "layer3"], @@ -451,6 +453,32 @@ def make_kv_cache_config( ), ) ) + if mamba_enabled: + kv_cache_groups.append( + KVCacheGroupSpec( + ["mamba0", "mamba1"], + MambaSpec( + block_size=block_size, + shapes=((16,), (16,)), + dtypes=(torch.float16,), + ), + ) + ) return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups ) + + +def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False): + """Create a NixlConnectorScheduler via __new__ (skipping __init__). + + Only sets the two flags needed by the N-1 prefill logic. + """ + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorScheduler, + ) + + sched = object.__new__(NixlConnectorScheduler) + sched._has_mamba = has_mamba + sched._is_hma_required = is_hma_required + return sched diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 79a04bcb95e0..ed53c35c9ed9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -572,6 +572,10 @@ def __init__( for g in kv_cache_config.kv_cache_groups ) ) + self._has_mamba = any( + isinstance(g.kv_cache_spec, MambaSpec) + for g in kv_cache_config.kv_cache_groups + ) logger.info("Initializing NIXL Scheduler %s", engine_id) if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: @@ -717,6 +721,39 @@ def _nixl_handshake_listener( logger.warning("Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) + def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int: + """D-side only. Returns N-1 for Mamba models since the decoder + always recomputes the last token and must start from h(N-1).""" + if self._has_mamba and num_prompt_tokens > 1: + return num_prompt_tokens - 1 + return num_prompt_tokens + + def _truncate_mamba_request_for_prefill(self, request: "Request") -> None: + """P-side only: drop the last prompt token so the prefiller computes + h(N-1) instead of h(N). The decoder recomputes the last token to + derive h(N) correctly. + + Guarded by ``_p_side_truncated`` to avoid repeated truncation if the + request is preempted and rescheduled.""" + params = request.kv_transfer_params + if ( + params is not None + # Guard against repeated truncation after preemption/reschedule. + and not params.get("_p_side_truncated") + and request.num_prompt_tokens > 1 + ): + if request.prompt_token_ids is not None: + request.prompt_token_ids.pop() + elif request.prompt_embeds is not None: + request.prompt_embeds = request.prompt_embeds[:-1] + else: + return + + request._all_token_ids.pop() + request.num_prompt_tokens -= 1 + request.max_tokens = 1 + params["_p_side_truncated"] = True + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int, bool]: @@ -746,10 +783,14 @@ def get_num_new_matched_tokens( if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. token_ids = request.prompt_token_ids or [] - count = len(token_ids) - num_computed_tokens + actual = self._mamba_prefill_token_count(len(token_ids)) + count = actual - num_computed_tokens if count > 0: return count, True + if params is not None and params.get("do_remote_decode") and self._has_mamba: + self._truncate_mamba_request_for_prefill(request) + # No remote prefill for this request. return 0, False