From c9c5b7e3fc8c8bafa558ca96c688fda2eee88acf Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 31 May 2026 12:34:53 -0700 Subject: [PATCH] [Bugfix][KV Connector] NIXL PD: don't transfer speculative-decode lookahead blocks Fixes a KV-corruption bug in NIXL prefill/decode disaggregation with speculative decoding, plus a tokenization bug in the PD acceptance test that was masking it. Bug 1 (correctness): prefill transfers lookahead blocks In request_finished, the prefill node reports remote_num_tokens = num_computed_tokens but sends its full block allocation, which with speculative decoding includes trailing lookahead-reservation blocks. When num_prompt_tokens % block_size == 0, the lookahead slot spills into an extra block, so len(remote_block_ids) == len(local_block_ids) + 1 on the decode side. _apply_prefix_caching then does remote[-num_local:] -- which assumes the surplus is an already-cached prefix and keeps the remote suffix. That drops the real first block and keeps the never-written lookahead block, shifting the entire block-to-block mapping. The decode node ends up attending to never-written KV (on both the target and the EAGLE3 drafter layers), producing wrong outputs for the affected requests. Fix: clip the transferred block_ids per KV cache group, using each group's own block_size, to the blocks covering num_computed_tokens. Self-attention groups (full / sliding-window, incl. MLA/sink subclasses) are clipped; state groups (Mamba/SSM) and any other spec whose length is not indexed by token count are passed through unchanged, so hybrid models are handled correctly. Bug 2 (test): double-BOS in the PD acceptance test test_spec_decode_acceptance.py sends already-chat-templated prompts (which contain the BOS token) through the completions API, which prepends another BOS by default. This double-BOS lowered acceptance ~5% versus the standalone baselines (which tokenize with add_special_tokens=False), affecting single-engine and PD identically. Set add_special_tokens=False so the test compares like-for-like. Test plan: - New unit tests in test_remote_decode_lifecycle.py (CPU): test_remote_decode_drops_lookahead_blocks (parametrized over 0/1/2 trailing lookahead blocks) and test_remote_decode_lookahead_clip_is_per_group (hybrid Mamba + attention with a per-group block_size). Both directly exercise request_finished; they fail without the fix and pass with it. Full file: 8 passed. - NIXL PD + EAGLE3 acceptance (Llama-3.1-8B, FLASH_ATTN, 2xGPU): per-position acceptance now 0.728 / 0.524 / 0.363 (vs baseline 0.730 / 0.521 / 0.354) and PASSES; V2 matches V1. Before the fix the misalignment dropped pos-2 acceptance and produced divergent outputs for prompt_len % block_size == 0 requests. Not a duplicate: checked open NIXL/PD/spec-decode PRs (#43151, #42554, #41169, #35264); none addresses the lookahead-block transfer / block-mapping misalignment. AI assistance (Claude) was used; all changes reviewed and tested by the submitter. Signed-off-by: Nick Hill Co-Authored-By: Claude Opus 4.8 (1M context) --- .../test_spec_decode_acceptance.py | 4 + .../unit/test_remote_decode_lifecycle.py | 128 ++++++++++++++++++ .../kv_connector/v1/nixl/scheduler.py | 21 +++ 3 files changed, 153 insertions(+) diff --git a/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py b/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py index c86a407ff8e3..15f386f5f5a8 100644 --- a/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py +++ b/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py @@ -158,6 +158,10 @@ def test_spec_decode_acceptance_length(): max_tokens=DEFAULT_OUTPUT_LEN, temperature=0.0, top_p=1.0, + # Prompts are already chat-templated (contain BOS); avoid the + # completions API prepending a second BOS, which would lower + # acceptance ~5% vs the add_special_tokens=False standalone baselines. + extra_body={"add_special_tokens": False}, ) if i < 3: text = resp.choices[0].text.strip()[:100] diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index b656e0809543..d1ab79e41454 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -3,7 +3,17 @@ import copy import pytest +import torch +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import ( + NixlConnectorScheduler, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + MambaSpec, +) from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput from vllm.v1.request import FinishReason, RequestStatus @@ -222,6 +232,124 @@ def test_prefix_cache_lifecycle(): assert_scheduler_empty(scheduler) +def _make_nixl_connector_scheduler( + vllm_config, kv_cache_groups=None +) -> NixlConnectorScheduler: + """Build a standalone NIXL connector-scheduler for directly exercising + ``request_finished``. Defaults to a single full-attention KV cache group.""" + if kv_cache_groups is None: + block_size = vllm_config.cache_config.block_size + kv_cache_groups = [ + KVCacheGroupSpec( + ["layer"], + FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), + ) + ] + kv_cache_config = KVCacheConfig( + num_blocks=10000, + kv_cache_tensors=[], + kv_cache_groups=kv_cache_groups, + ) + return NixlConnectorScheduler(vllm_config, "test-engine-id", kv_cache_config) + + +@pytest.mark.parametrize("extra_lookahead_blocks", [0, 1, 2]) +def test_remote_decode_drops_lookahead_blocks(extra_lookahead_blocks): + """Regression test: request_finished must transfer exactly the blocks + holding the computed KV, not the spec-decode lookahead reservation blocks + allocated past num_computed_tokens. Sending an extra block lets the decode + node's suffix-trim misalign the sequence and read stale KV. + """ + vllm_config = create_vllm_config() + connector = _make_nixl_connector_scheduler(vllm_config) + + block_size = vllm_config.cache_config.block_size + # Multiple of block_size: the worst case where the lookahead slot needs a + # brand-new block. Allocate prompt blocks + the lookahead reservation. + num_computed_tokens = 4 * block_size + num_prompt_blocks = num_computed_tokens // block_size # == 4 + allocated_block_ids = list(range(1, num_prompt_blocks + extra_lookahead_blocks + 1)) + + request = create_request( + request_id=1, + block_size=block_size, + num_tokens=num_computed_tokens, + do_remote_decode=True, + ) + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + request.num_computed_tokens = num_computed_tokens + + delay_free_blocks, params = connector.request_finished( + request, (allocated_block_ids,) + ) + + assert delay_free_blocks is True + assert params is not None + # Trailing lookahead blocks dropped, regardless of how many were allocated. + assert params["remote_block_ids"] == ([1, 2, 3, 4],) + assert len(params["remote_block_ids"][0]) == num_prompt_blocks + assert params["remote_num_tokens"] == num_computed_tokens + + +def test_remote_decode_lookahead_clip_is_per_group(): + """Clipping is per-group with each group's own block_size: in a hybrid + model the attention group is clipped while a Mamba/SSM state group is left + untouched. The attention group uses a block_size != the global one, so a + global-block_size implementation would clip it incorrectly. + """ + vllm_config = create_vllm_config() + global_block_size = vllm_config.cache_config.block_size # 16 + attn_block_size = 2 * global_block_size # 32 + + kv_cache_groups = [ + KVCacheGroupSpec( + ["mamba_layer"], + MambaSpec( + block_size=global_block_size, + shapes=((1,),), + dtypes=(torch.float32,), + ), + ), + KVCacheGroupSpec( + ["attn_layer"], + FullAttentionSpec( + block_size=attn_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), + ), + ] + connector = _make_nixl_connector_scheduler(vllm_config, kv_cache_groups) + + # 64 tokens => 2 attn blocks at block_size 32, + 1 lookahead block. + # (cdiv(64, 16) == 4 would not clip, so this fails with the global size.) + num_computed_tokens = 2 * attn_block_size # 64 + request = create_request( + request_id=1, + block_size=attn_block_size, + num_tokens=num_computed_tokens, + do_remote_decode=True, + ) + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + request.num_computed_tokens = num_computed_tokens + + # group 0: Mamba state block; group 1: 2 prompt blocks + 1 lookahead block. + mamba_block_ids = [101] + attn_block_ids = [1, 2, 3] + _, params = connector.request_finished(request, (mamba_block_ids, attn_block_ids)) + + # Mamba group passed through; attention group clipped at its own block_size. + remote_block_ids = params["remote_block_ids"] + assert remote_block_ids[0] == [101] + assert remote_block_ids[1] == [1, 2] + + def test_abort_during_kv_transfer(): """Test aborting request does not release blocks for remote decode.""" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py index b2122ed0d30b..1293a35cce82 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py @@ -661,6 +661,27 @@ def request_finished( remote_num_tokens = request.num_computed_tokens + # Drop trailing blocks allocated beyond num_computed_tokens. With + # speculative decoding the scheduler reserves lookahead slots that + # spill into an extra block when num_computed_tokens is a multiple + # of block_size. Sending it makes remote_block_ids longer than the + # decode allocation, so the decode's suffix-trim + # (_apply_prefix_caching) keeps the never-written lookahead block + # and drops a real one, shifting the mapping -> stale KV reads. + # Clip per group (own block_size) for self-attention groups; leave + # state groups (Mamba/SSM) and others not indexed by token count. + if remote_num_tokens > 0: + kv_cache_groups = self.kv_cache_config.kv_cache_groups + clipped = list(block_ids) + for i, group_spec in enumerate(kv_cache_groups): + spec = group_spec.kv_cache_spec + if not isinstance(spec, (FullAttentionSpec, SlidingWindowSpec)): + continue + num_written_blocks = cdiv(remote_num_tokens, spec.block_size) + if len(clipped[i]) > num_written_blocks: + clipped[i] = clipped[i][:num_written_blocks] + block_ids = tuple(clipped) + return delay_free_blocks, dict( do_remote_prefill=is_p_node, do_remote_decode=is_d_node,