Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def test_spec_decode_acceptance_length():
max_tokens=DEFAULT_OUTPUT_LEN,
temperature=0.0,
top_p=1.0,
# Prompts are already chat-templated (contain BOS); avoid the
# completions API prepending a second BOS, which would lower
# acceptance ~5% vs the add_special_tokens=False standalone baselines.
extra_body={"add_special_tokens": False},
)
if i < 3:
text = resp.choices[0].text.strip()[:100]
Expand Down
128 changes: 128 additions & 0 deletions tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@
import copy

import pytest
import torch

from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus

Expand Down Expand Up @@ -222,6 +232,124 @@ def test_prefix_cache_lifecycle():
assert_scheduler_empty(scheduler)


def _make_nixl_connector_scheduler(
    vllm_config, kv_cache_groups=None
) -> NixlConnectorScheduler:
    """Create a NIXL connector-scheduler that tests can call directly.

    The returned object is constructed on its own so that
    ``request_finished`` can be exercised without the surrounding helpers.

    Args:
        vllm_config: Config handed to the connector; its
            ``cache_config.block_size`` sizes the default group.
        kv_cache_groups: KV cache group specs to use. When ``None``, one
            full-attention group with a single layer is created.

    Returns:
        A ``NixlConnectorScheduler`` built with engine id
        ``"test-engine-id"``.
    """
    groups = kv_cache_groups
    if groups is None:
        # Minimal full-attention spec; only block_size comes from the config.
        default_spec = FullAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
        )
        groups = [KVCacheGroupSpec(["layer"], default_spec)]

    return NixlConnectorScheduler(
        vllm_config,
        "test-engine-id",
        KVCacheConfig(
            num_blocks=10000,
            kv_cache_tensors=[],
            kv_cache_groups=groups,
        ),
    )


@pytest.mark.parametrize("extra_lookahead_blocks", [0, 1, 2])
def test_remote_decode_drops_lookahead_blocks(extra_lookahead_blocks):
    """Regression: only blocks that hold computed KV are sent to the decoder.

    Spec-decode lookahead reservations can leave blocks allocated past
    ``num_computed_tokens``. ``request_finished`` must not include them in
    ``remote_block_ids``; an extra block would let the decode node's
    suffix-trim misalign the sequence and read stale KV.
    """
    config = create_vllm_config()
    tokens_per_block = config.cache_config.block_size
    nixl_scheduler = _make_nixl_connector_scheduler(config)

    # An exact multiple of the block size is the worst case: the lookahead
    # slot cannot share the last prompt block and needs a brand-new one.
    computed_tokens = 4 * tokens_per_block
    prompt_block_count = computed_tokens // tokens_per_block  # == 4
    # Block ids 1..N: the prompt blocks followed by the lookahead reservation.
    total_allocated = prompt_block_count + extra_lookahead_blocks
    allocated = list(range(1, total_allocated + 1))

    request = create_request(
        request_id=1,
        block_size=tokens_per_block,
        num_tokens=computed_tokens,
        do_remote_decode=True,
    )
    request.status = RequestStatus.FINISHED_LENGTH_CAPPED
    request.num_computed_tokens = computed_tokens

    should_delay_free, transfer_params = nixl_scheduler.request_finished(
        request, (allocated,)
    )

    assert should_delay_free is True
    assert transfer_params is not None
    # However many lookahead blocks were allocated, all of them are dropped.
    sent_block_ids = transfer_params["remote_block_ids"]
    assert sent_block_ids == ([1, 2, 3, 4],)
    assert len(sent_block_ids[0]) == prompt_block_count
    assert transfer_params["remote_num_tokens"] == computed_tokens


def test_remote_decode_lookahead_clip_is_per_group():
    """Lookahead clipping uses each KV cache group's own block size.

    In a hybrid model the attention group is clipped while a Mamba/SSM state
    group is passed through untouched. The attention group here uses a block
    size different from the global one, so an implementation that clips with
    the global block size would compute the wrong block count and fail.
    """
    vllm_config = create_vllm_config()
    global_block_size = vllm_config.cache_config.block_size  # 16
    attn_block_size = 2 * global_block_size  # 32

    # Group 0 is the state (Mamba) group, group 1 the attention group; the
    # order must match the block-id tuple passed to request_finished below.
    kv_cache_groups = [
        KVCacheGroupSpec(
            ["mamba_layer"],
            MambaSpec(
                block_size=global_block_size,
                shapes=((1,),),
                dtypes=(torch.float32,),
            ),
        ),
        KVCacheGroupSpec(
            ["attn_layer"],
            FullAttentionSpec(
                block_size=attn_block_size,
                num_kv_heads=1,
                head_size=1,
                dtype=torch.float32,
            ),
        ),
    ]
    connector = _make_nixl_connector_scheduler(vllm_config, kv_cache_groups)

    # 64 tokens => 2 attn blocks at block_size 32, + 1 lookahead block.
    # (cdiv(64, 16) == 4 would not clip, so this fails with the global size.)
    num_computed_tokens = 2 * attn_block_size  # 64
    request = create_request(
        request_id=1,
        block_size=attn_block_size,
        num_tokens=num_computed_tokens,
        do_remote_decode=True,
    )
    request.status = RequestStatus.FINISHED_LENGTH_CAPPED
    request.num_computed_tokens = num_computed_tokens

    # group 0: Mamba state block; group 1: 2 prompt blocks + 1 lookahead block.
    mamba_block_ids = [101]
    attn_block_ids = [1, 2, 3]
    _, params = connector.request_finished(
        request, (mamba_block_ids, attn_block_ids)
    )

    # Fail with a clear assertion (as the single-group test does) rather than
    # a TypeError from subscripting None if no transfer params are returned.
    assert params is not None

    # Mamba group passed through; attention group clipped at its own block_size.
    remote_block_ids = params["remote_block_ids"]
    assert remote_block_ids[0] == [101]
    assert remote_block_ids[1] == [1, 2]


def test_abort_during_kv_transfer():
"""Test aborting request does not release blocks for remote decode."""

Expand Down
21 changes: 21 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,27 @@ def request_finished(

remote_num_tokens = request.num_computed_tokens

# Drop trailing blocks allocated beyond num_computed_tokens. With
# speculative decoding the scheduler reserves lookahead slots that
# spill into an extra block when num_computed_tokens is a multiple
# of block_size. Sending it makes remote_block_ids longer than the
# decode allocation, so the decode's suffix-trim
# (_apply_prefix_caching) keeps the never-written lookahead block
# and drops a real one, shifting the mapping -> stale KV reads.
# Clip per group (own block_size) for self-attention groups; leave
# state groups (Mamba/SSM) and others not indexed by token count.
if remote_num_tokens > 0:
kv_cache_groups = self.kv_cache_config.kv_cache_groups
clipped = list(block_ids)
for i, group_spec in enumerate(kv_cache_groups):
spec = group_spec.kv_cache_spec
if not isinstance(spec, (FullAttentionSpec, SlidingWindowSpec)):
continue
num_written_blocks = cdiv(remote_num_tokens, spec.block_size)
if len(clipped[i]) > num_written_blocks:
clipped[i] = clipped[i][:num_written_blocks]
block_ids = tuple(clipped)

return delay_free_blocks, dict(
do_remote_prefill=is_p_node,
do_remote_decode=is_d_node,
Expand Down
Loading