Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2012,7 +2012,7 @@ def test_transfer_failure_logging(
connector = NixlConnector(
vllm_config,
KVConnectorRole.WORKER,
make_kv_cache_config(block_size=16, hma_enabled=enable_hma),
make_kv_cache_config(block_size=16, swa_enabled=enable_hma),
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config,
Expand Down
121 changes: 113 additions & 8 deletions tests/v1/kv_connector/unit/test_nixl_connector_hma.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
"""Unit tests for NixlConnectorScheduler with HMA and Mamba N-1 prefill."""

from unittest.mock import patch

Expand All @@ -14,24 +14,26 @@
)

from .utils import (
create_request,
create_vllm_config,
make_kv_cache_config,
make_nixl_scheduler,
)


@pytest.mark.cpu_test
@pytest.mark.parametrize(
"hma_enabled,expected_sw_sizes",
"swa_enabled,expected_sw_sizes",
[
# HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
# SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
(True, [0, 128 + 1]),
# HMA disabled: only FullAttentionSpec (0)
# SWA disabled: only FullAttentionSpec (0)
(False, [0]),
],
)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
"""Test sw_sizes is correctly computed based on HMA enabled/disabled."""
def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes):
"""Test sw_sizes is correctly computed based on SWA enabled/disabled."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorScheduler,
)
Expand All @@ -42,7 +44,7 @@ def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
vllm_config = create_vllm_config(block_size=block_size)
# SW 2048 tokens=>128 blocks
kv_cache_config = make_kv_cache_config(
block_size=block_size, hma_enabled=hma_enabled, sw_size=2048
block_size=block_size, swa_enabled=swa_enabled, sw_size=2048
)

scheduler = NixlConnectorScheduler(
Expand Down Expand Up @@ -75,7 +77,7 @@ def test_logical_to_kernel_block_ids_with_hma():
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker._physical_blocks_per_logical_kv_block = 2
# FA + SW groups (neither is MambaSpec, so both get expanded)
worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True)
worker.kv_cache_config = make_kv_cache_config(block_size=16, swa_enabled=True)

# Test conversion: FA + SW group
logical_block_ids = [[0, 1, 2], [3, 4]]
Expand Down Expand Up @@ -313,3 +315,106 @@ def test_nixl_metadata_hybrid_ssm_block_ids():
assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
assert list(req_meta.remote.block_ids[1]) == [20, 21]
assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1])


# ── Mamba N-1 prefill tests ──────────────────────────────────────────────


@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "has_mamba,is_hma_required,expected_count",
    [
        (True, True, 9),
        (False, False, 10),
        (False, True, 10),
    ],
    ids=["mamba", "fa_only", "swa_only"],
)
def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count):
    """D-side: Mamba gets N-1 matched tokens, non-Mamba gets N."""
    scheduler = make_nixl_scheduler(
        has_mamba=has_mamba, is_hma_required=is_hma_required
    )
    request = create_request(num_tokens=10, do_remote_prefill=True)

    matched, async_load = scheduler.get_num_new_matched_tokens(
        request, num_computed_tokens=0
    )
    assert matched == expected_count
    assert async_load is True


@pytest.mark.cpu_test
def test_mamba_n1_p_side_truncation():
    """P-side: Mamba truncates the prompt to N-1 and caps max_tokens at 1.

    Also checks that truncation is idempotent (safe under preemption and
    rescheduling via the ``_p_side_truncated`` guard) and that non-Mamba
    models are left untouched.
    """
    scheduler = make_nixl_scheduler(has_mamba=True, is_hma_required=True)
    request = create_request(num_tokens=10, do_remote_decode=True)
    request.max_tokens = 128
    n_before = len(request.prompt_token_ids)

    matched, async_load = scheduler.get_num_new_matched_tokens(
        request, num_computed_tokens=0
    )

    assert matched == 0
    assert async_load is False
    assert len(request.prompt_token_ids) == n_before - 1
    assert request.num_prompt_tokens == n_before - 1
    assert request.max_tokens == 1
    assert request.kv_transfer_params["_p_side_truncated"] is True

    # A second call must be a no-op (idempotency guard).
    scheduler.get_num_new_matched_tokens(request, num_computed_tokens=0)
    assert len(request.prompt_token_ids) == n_before - 1

    # A model without Mamba layers must never be truncated.
    plain_scheduler = make_nixl_scheduler(has_mamba=False, is_hma_required=False)
    plain_request = create_request(num_tokens=10, do_remote_decode=True)
    plain_len = len(plain_request.prompt_token_ids)

    plain_scheduler.get_num_new_matched_tokens(plain_request, num_computed_tokens=0)
    assert len(plain_request.prompt_token_ids) == plain_len


@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "swa_enabled,mamba_enabled,expected_has_mamba,expected_is_hma",
    [
        (True, True, True, True),
        (True, False, False, True),
        (False, False, False, False),
    ],
    ids=["fa_swa_mamba", "fa_swa_only", "fa_only"],
)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_has_mamba_init(
    mock_platform,
    swa_enabled,
    mamba_enabled,
    expected_has_mamba,
    expected_is_hma,
):
    """_has_mamba / _is_hma_required must be derived from kv_cache_groups."""
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorScheduler,
    )

    mock_platform.device_type = "cpu"

    blk_size = 16
    cfg = create_vllm_config(block_size=blk_size)
    # VllmConfig.__post_init__ auto-disables HMA whenever kv_transfer_config
    # is set; undo that here so the scheduler's own derivation is exercised.
    cfg.scheduler_config.disable_hybrid_kv_cache_manager = False
    kv_cfg = make_kv_cache_config(
        block_size=blk_size,
        swa_enabled=swa_enabled,
        mamba_enabled=mamba_enabled,
    )

    nixl_scheduler = NixlConnectorScheduler(
        vllm_config=cfg,
        engine_id="test-engine",
        kv_cache_config=kv_cfg,
    )
    assert nixl_scheduler._has_mamba is expected_has_mamba
    assert nixl_scheduler._is_hma_required is expected_is_hma
78 changes: 77 additions & 1 deletion tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from unittest.mock import patch

import pytest

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
from vllm.v1.request import FinishReason, RequestStatus

from .utils import (
Expand All @@ -13,6 +18,7 @@
create_request,
create_scheduler,
create_vllm_config,
make_kv_cache_config,
)

pytestmark = pytest.mark.cpu_test
Expand Down Expand Up @@ -579,3 +585,73 @@ def test_cannot_recv():
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)


@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_p_side_chunked_prefill_mamba(mock_platform):
    """P-side integration: Mamba N-1 truncation + chunked prefill completes.

    A 64-token P-side request is truncated to 63 by the N-1 fix, then
    chunked into two prefill steps (32 + 31) and finishes length-capped
    (``FinishReason.LENGTH``) because max_tokens is set to 1.
    """
    # The connector consults current_platform; pin it to CPU so the test
    # does not depend on the host's accelerator.
    mock_platform.device_type = "cpu"

    BATCH_SIZE = 32  # max_num_batched_tokens, i.e. the prefill chunk size
    NUM_TOKENS = 64  # prompt length before the N-1 truncation
    BLOCK_SIZE = 16

    vllm_config = create_vllm_config(
        max_num_batched_tokens=BATCH_SIZE,
        block_size=BLOCK_SIZE,
    )
    # VllmConfig.__post_init__ auto-disables the hybrid KV-cache manager when
    # kv_transfer_config is set; re-enable it so the Mamba path is exercised.
    vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False

    kv_cache_config = make_kv_cache_config(
        block_size=BLOCK_SIZE,
        mamba_enabled=True,
        num_blocks=10000,
    )

    scheduler = create_scheduler(vllm_config, kv_cache_config=kv_cache_config)

    # do_remote_decode marks this as a P-side (prefiller) request.
    request = create_request(
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
        block_size=BLOCK_SIZE,
    )
    request.max_tokens = 128
    scheduler.add_request(request)
    request_id = request.request_id

    # ── Step 1: first chunk ──
    scheduler_output = scheduler.schedule()

    # The N-1 fix drops the final prompt token and caps generation at 1.
    assert len(request.prompt_token_ids) == NUM_TOKENS - 1
    assert request.max_tokens == 1
    # Only a full batch worth of tokens fits into the first chunk.
    assert scheduler_output.num_scheduled_tokens[request_id] == BATCH_SIZE
    assert request.num_computed_tokens == BATCH_SIZE

    # Model returns no tokens for intermediate prefill chunk
    intermediate_output = ModelRunnerOutput(
        req_ids=[request.request_id],
        req_id_to_index={request.request_id: 0},
        sampled_token_ids=[[]],
    )
    scheduler.update_from_output(scheduler_output, intermediate_output)

    # ── Step 2: remaining chunk ──
    scheduler_output = scheduler.schedule()

    remaining = NUM_TOKENS - 1 - BATCH_SIZE  # 31
    assert scheduler_output.num_scheduled_tokens[request_id] == remaining
    assert request.num_computed_tokens == NUM_TOKENS - 1

    # Prefill complete: model generates 1 decode token
    final_output = create_model_runner_output([request])
    engine_core_outputs = scheduler.update_from_output(scheduler_output, final_output)

    # max_tokens=1 → request finishes with LENGTH
    outputs = engine_core_outputs[0].outputs
    assert len(outputs) == 1
    assert outputs[0].finish_reason == FinishReason.LENGTH
32 changes: 30 additions & 2 deletions tests/v1/kv_connector/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
SlidingWindowSpec,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
Expand Down Expand Up @@ -423,7 +424,8 @@ def wait_for_save(self):

def make_kv_cache_config(
block_size: int,
hma_enabled: bool = False,
swa_enabled: bool = False,
mamba_enabled: bool = False,
sw_size: int = 128,
num_blocks: int = 100,
) -> KVCacheConfig:
Expand All @@ -438,7 +440,7 @@ def make_kv_cache_config(
),
)
]
if hma_enabled:
if swa_enabled:
kv_cache_groups.append(
KVCacheGroupSpec(
["layer1", "layer3"],
Expand All @@ -451,6 +453,32 @@ def make_kv_cache_config(
),
)
)
if mamba_enabled:
kv_cache_groups.append(
KVCacheGroupSpec(
["mamba0", "mamba1"],
MambaSpec(
block_size=block_size,
shapes=((16,), (16,)),
dtypes=(torch.float16,),
),
)
)
return KVCacheConfig(
num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups
)


def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False):
    """Build a minimally-initialized NixlConnectorScheduler for unit tests.

    Bypasses ``__init__`` entirely via ``object.__new__``; only the two
    flags consumed by the Mamba N-1 prefill logic are populated.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorScheduler,
    )

    scheduler = object.__new__(NixlConnectorScheduler)
    scheduler._has_mamba = has_mamba
    scheduler._is_hma_required = is_hma_required
    return scheduler
43 changes: 42 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,10 @@ def __init__(
for g in kv_cache_config.kv_cache_groups
)
)
self._has_mamba = any(
isinstance(g.kv_cache_spec, MambaSpec)
for g in kv_cache_config.kv_cache_groups
)

logger.info("Initializing NIXL Scheduler %s", engine_id)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
Expand Down Expand Up @@ -717,6 +721,39 @@ def _nixl_handshake_listener(
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))

def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int:
    """D-side only. For Mamba models the decoder always recomputes the
    final prompt token and must resume from h(N-1), so report N-1
    matched tokens; otherwise report N unchanged."""
    if not self._has_mamba or num_prompt_tokens <= 1:
        return num_prompt_tokens
    return num_prompt_tokens - 1

def _truncate_mamba_request_for_prefill(self, request: "Request") -> None:
    """P-side only: drop the last prompt token so the prefiller computes
    h(N-1) instead of h(N). The decoder recomputes the last token to
    derive h(N) correctly.

    Guarded by ``_p_side_truncated`` to avoid repeated truncation if the
    request is preempted and rescheduled."""
    params = request.kv_transfer_params
    if (
        params is not None
        # Guard against repeated truncation after preemption/reschedule.
        and not params.get("_p_side_truncated")
        # Single-token prompts cannot be shortened further.
        and request.num_prompt_tokens > 1
    ):
        # Drop the final prompt element from whichever representation
        # this request carries (token ids or embeddings).
        if request.prompt_token_ids is not None:
            request.prompt_token_ids.pop()
        elif request.prompt_embeds is not None:
            request.prompt_embeds = request.prompt_embeds[:-1]
        else:
            # Neither tokens nor embeds present; nothing to truncate.
            # NOTE(review): presumably unreachable when num_prompt_tokens > 1
            # — confirm against Request invariants.
            return

        # Keep the request's combined token list and bookkeeping in sync
        # with the shortened prompt, and cap generation at a single token
        # (the prefiller only needs to produce h(N-1)).
        request._all_token_ids.pop()
        request.num_prompt_tokens -= 1
        request.max_tokens = 1
        # Marker consumed by the guard above on any later reschedule.
        params["_p_side_truncated"] = True
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we have this param?
D is not supposed to get here because of do_remote_decode guard.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gemini-code-assist made the suggestion here: ZhanqiuHu#4 (comment).

Screenshot 2026-03-17 at 2 53 26 PM

I think this makes sure we don't -1 multiple times if the request gets rescheduled, although it doesn't necessarily need to be inside the params.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZhanqiuHu I don't think it can happen as described here, but this is very much valid for preemptions. Let's add a comment

Comment on lines +739 to +755
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This truncation logic only considers prompt_token_ids and will not apply the N-1 prefill fix for Mamba models when prompt_embeds are used. This could lead to state corruption and incorrect outputs in that scenario. The logic should be updated to handle prompt_embeds as well.

        if (
            params is not None
            and not params.get("_p_side_truncated")
            and request.num_prompt_tokens > 1
        ):
            if request.prompt_token_ids is not None:
                request.prompt_token_ids.pop()
            elif request.prompt_embeds is not None:
                request.prompt_embeds = request.prompt_embeds[:-1]
            else:
                # This case should not be possible if num_prompt_tokens > 1.
                return

            request._all_token_ids.pop()
            request.num_prompt_tokens -= 1
            request.max_tokens = 1
            params["_p_side_truncated"] = True


def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
Expand Down Expand Up @@ -746,10 +783,14 @@ def get_num_new_matched_tokens(
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens
actual = self._mamba_prefill_token_count(len(token_ids))
count = actual - num_computed_tokens
if count > 0:
return count, True

if params is not None and params.get("do_remote_decode") and self._has_mamba:
self._truncate_mamba_request_for_prefill(request)

# No remote prefill for this request.
return 0, False

Expand Down
Loading