147 changes: 147 additions & 0 deletions tests/v1/spec_decode/test_backup_token_async_spec.py
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Regression tests for the backup token fix in prepare_next_token_ids_padded.

Fixes #38098: with async scheduling, seq_lens_cpu is inflated by unaccepted
draft token placeholders, causing get_token_id() to return -1.
"""

from __future__ import annotations

import numpy as np
import pytest
import torch


class _FakeRequest:
def __init__(self, prompt_tokens: list[int], output_tokens: list[int]):
self.num_prompt_tokens = len(prompt_tokens)
self._prompt = prompt_tokens
self._output = output_tokens

@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self._output)

def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self._prompt[idx]
out_idx = idx - self.num_prompt_tokens
if out_idx < len(self._output):
return self._output[out_idx]
return -1 # out of range


class _FakeInputBatch:
def __init__(
self,
req_ids: list[str],
num_tokens_no_spec: list[int],
vocab_size: int = 32000,
):
self.req_ids = req_ids
self.num_reqs = len(req_ids)
self.vocab_size = vocab_size
self.num_tokens_no_spec = np.array(num_tokens_no_spec, dtype=np.int64)


def _make_requests(
req_ids: list[str],
prompt_lens: list[int],
output_lens: list[int],
) -> dict[str, _FakeRequest]:
requests = {}
for rid, plen, olen in zip(req_ids, prompt_lens, output_lens):
requests[rid] = _FakeRequest(list(range(plen)), list(range(1000, 1000 + olen)))
return requests


def _backup_buggy(
seq_lens_cpu: torch.Tensor,
requests: dict[str, _FakeRequest],
batch: _FakeInputBatch,
) -> list[int]:
"""Old logic: uses seq_lens_cpu directly (may be inflated)."""
n = batch.num_reqs
return [
requests[batch.req_ids[i]].get_token_id(int(seq_lens_cpu[i])) for i in range(n)
]


def _backup_fixed(
requests: dict[str, _FakeRequest],
batch: _FakeInputBatch,
) -> list[int]:
"""New logic: uses num_tokens_no_spec - 1 (last committed token)."""
n = batch.num_reqs
idx = (batch.num_tokens_no_spec[:n] - 1).tolist()
return [requests[batch.req_ids[i]].get_token_id(int(idx[i])) for i in range(n)]


class TestBackupTokenAsyncSpec:
def test_no_inflation_fixed_returns_last_token(self):
req_ids = ["r0", "r1"]
requests = _make_requests(req_ids, [3, 3], [2, 2])
batch = _FakeInputBatch(req_ids, [5, 5])
# idx = 5-1 = 4 → output[1] = 1001
assert _backup_fixed(requests, batch) == [1001, 1001]

def test_inflation_buggy_returns_placeholder(self):
req_ids = ["r0", "r1"]
requests = _make_requests(req_ids, [3, 3], [2, 2])
batch = _FakeInputBatch(req_ids, [5, 5])
# inflated by 3 spec tokens → idx 8 is out of range
seq_lens = torch.tensor([8, 8], dtype=torch.int64)
assert _backup_buggy(seq_lens, requests, batch) == [-1, -1]

def test_inflation_fixed_returns_correct_token(self):
req_ids = ["r0", "r1"]
requests = _make_requests(req_ids, [3, 3], [2, 2])
batch = _FakeInputBatch(req_ids, [5, 5])
assert _backup_fixed(requests, batch) == [1001, 1001]

def test_mixed_inflation_per_request(self):
req_ids = ["r0", "r1", "r2"]
requests = {
"r0": _FakeRequest([0, 1], [1000, 1001, 1002]),
"r1": _FakeRequest([0, 1, 2, 3], [2000]),
"r2": _FakeRequest([0], [3000, 3001, 3002, 3003]),
}
batch = _FakeInputBatch(req_ids, [5, 5, 5])
seq_lens = torch.tensor([7, 9, 5], dtype=torch.int64)

assert _backup_buggy(seq_lens, requests, batch) == [-1, -1, -1]
assert _backup_fixed(requests, batch) == [1002, 2000, 3003]

def test_prefill_only_request(self):
"""No output tokens yet — backup should be the last prompt token."""
req_ids = ["r0"]
requests = {"r0": _FakeRequest([10, 20, 30], [])}
batch = _FakeInputBatch(req_ids, [3])
# idx = 3-1 = 2 → prompt[2] = 30
assert _backup_fixed(requests, batch) == [30]

    @pytest.mark.parametrize("num_spec_tokens", [1, 2, 3, 4, 5])
    def test_various_spec_token_counts(self, num_spec_tokens: int):
        req_ids = ["r0"]
        requests = {"r0": _FakeRequest([0, 1, 2], list(range(1000, 1005)))}
        batch = _FakeInputBatch(req_ids, [8])
        # inflated seq_len = 8 + num_spec_tokens → buggy index is always out of range
        seq_lens = torch.tensor([8 + num_spec_tokens], dtype=torch.int64)
        assert _backup_buggy(seq_lens, requests, batch) == [-1]
        # idx = 8-1 = 7 → output[4] = 1004, independent of the speculation depth
        assert _backup_fixed(requests, batch) == [1004]

def test_buggy_code_was_always_off_by_one(self):
"""The original code used seq_len as index, which is always one past
the end of output_token_ids even without async inflation."""
req_ids = ["r0"]
requests = {"r0": _FakeRequest([0, 1, 2], [1000, 1001])}
batch = _FakeInputBatch(req_ids, [5])

# no inflation: seq_len == num_tokens == 5 → idx 5 is out of range
seq_lens = torch.tensor([5], dtype=torch.int64)
assert _backup_buggy(seq_lens, requests, batch) == [-1]
assert _backup_fixed(requests, batch) == [1001]

# with inflation: still -1, fixed still correct
seq_lens_inf = torch.tensor([8], dtype=torch.int64)
assert _backup_buggy(seq_lens_inf, requests, batch) == [-1]
assert _backup_fixed(requests, batch) == [1001]
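
The indexing arithmetic these tests exercise, as a minimal standalone sketch (illustration only, not part of the changed files); the token values mirror the fake requests above:

prompt = [0, 1, 2]                          # 3 prompt tokens
output = [1000, 1001]                       # 2 committed output tokens
tokens = prompt + output
num_tokens_no_spec = len(tokens)            # 5
inflated_seq_len = num_tokens_no_spec + 3   # 8: includes unaccepted draft placeholders

# The old backup index points past the committed tokens, so the lookup fails (-1).
assert inflated_seq_len >= len(tokens)
# The fixed backup index lands on the last committed token.
assert tokens[num_tokens_no_spec - 1] == 1001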
16 changes: 4 additions & 12 deletions tests/v1/spec_decode/test_eagle.py
@@ -3,6 +3,7 @@

from unittest import mock

import numpy as np
import pytest
import torch

@@ -127,16 +128,14 @@ def test_prepare_next_token_ids():

num_requests = 4
num_speculative_tokens = 4
batch_spec = BatchSpec(
seq_lens=[num_speculative_tokens + 1] * num_requests,
query_lens=[num_speculative_tokens + 1] * num_requests,
)

req_ids = [f"req_{i + 1}" for i in range(num_requests)]
mock_input_batch = mock.MagicMock(spec=InputBatch)
mock_input_batch.req_ids = req_ids
mock_input_batch.num_reqs = num_requests
mock_input_batch.vocab_size = 100
mock_input_batch.num_tokens_no_spec = np.array(
[num_speculative_tokens + 1] * num_requests
)

mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
mock_requests = {}
@@ -181,19 +180,12 @@ def test_prepare_next_token_ids():

assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)

common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=BLOCK_SIZE,
device=device,
)

expected_valid_sampled_tokens_count = torch.tensor(
[2, 5, 0, 0], dtype=torch.int32, device=device
)

next_token_ids_from_padded, valid_sampled_tokens_count = (
proposer.prepare_next_token_ids_padded(
common_attn_metadata.seq_lens_cpu,
sampled_token_ids_tensor,
mock_requests,
mock_input_batch,
14 changes: 2 additions & 12 deletions tests/v1/spec_decode/test_extract_hidden_states.py
@@ -3,6 +3,7 @@

from unittest import mock

import numpy as np
import pytest
import torch

@@ -132,16 +133,12 @@ def test_prepare_next_token_ids_padded():
device = torch.device(current_platform.device_type)

num_requests = 4
batch_spec = BatchSpec(
seq_lens=[5] * num_requests,
query_lens=[5] * num_requests,
)

req_ids = [f"req_{i + 1}" for i in range(num_requests)]
mock_input_batch = mock.MagicMock(spec=InputBatch)
mock_input_batch.req_ids = req_ids
mock_input_batch.num_reqs = num_requests
mock_input_batch.vocab_size = 100
mock_input_batch.num_tokens_no_spec = np.array([5] * num_requests)

mock_requests = {}
for req_id in req_ids:
@@ -174,20 +171,13 @@ def test_prepare_next_token_ids_padded():

proposer = _create_proposer(num_speculative_tokens=1)

common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)

# valid_sampled_tokens_count tracks if token is valid (not -1 and in vocab range)
# It doesn't depend on whether the request is discarded
expected_valid_sampled_tokens_count = torch.tensor(
[1, 1, 0, 1], dtype=torch.int32, device=device
)

next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
common_attn_metadata.seq_lens_cpu,
sampled_token_ids,
mock_requests,
mock_input_batch,
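The expected counts in the test above follow from how a sampled token is judged valid. A hedged sketch of that criterion, assuming validity means "not -1 and within the vocab range" as the comment states (vocab_size of 100 and the [1, 1, 0, 1] expectation come from the test; the proposer's actual computation may differ in detail):

import torch

vocab_size = 100
# One sampled token per request; the third request sampled the placeholder -1.
sampled_token_ids = torch.tensor([[5], [99], [-1], [42]])
valid_sampled_tokens_count = (
    (sampled_token_ids != -1) & (sampled_token_ids < vocab_size)
).sum(dim=1, dtype=torch.int32)
assert valid_sampled_tokens_count.tolist() == [1, 1, 0, 1]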
3 changes: 1 addition & 2 deletions vllm/v1/spec_decode/eagle.py
@@ -859,7 +859,6 @@ def prepare_next_token_ids_cpu(

def prepare_next_token_ids_padded(
self,
seq_lens_cpu: torch.Tensor,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
@@ -874,7 +873,7 @@ def prepare_next_token_ids_padded(
"""
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
self.backup_next_token_ids.np[:num_reqs] = np.array(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
3 changes: 1 addition & 2 deletions vllm/v1/spec_decode/extract_hidden_states.py
@@ -286,7 +286,6 @@ def _build_attn_metadata_builder(

def prepare_next_token_ids_padded(
self,
seq_lens: torch.Tensor,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
@@ -303,7 +302,7 @@ def prepare_next_token_ids_padded(
device = sampled_token_ids.device

# Compute backup tokens for discarded / invalid requests
seq_lens_list = seq_lens[:num_reqs].tolist()
seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
backup_tokens_gpu = torch.tensor(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
30 changes: 22 additions & 8 deletions vllm/v1/worker/gpu_model_runner.py
@@ -210,7 +210,7 @@
if TYPE_CHECKING:
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
from vllm.v1.worker.gpu.mm.encoder_cudagraph import EncoderCudaGraphManager

logger = init_logger(__name__)

@@ -1939,9 +1939,24 @@ def _prepare_inputs(
# _update_states_after_model_execute for hybrid models).
if self.num_accepted_tokens_event is not None:
self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
# Async mode: condense() reordered indices, use prev_positions mapping
if self.use_async_scheduling and prev_req_id_to_index:
prev_idx = self.prev_positions.np[:num_reqs]
new_mask = prev_idx < 0
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[
np.where(new_mask, 0, prev_idx)
]
)
self.num_accepted_tokens.np[:num_reqs][new_mask] = 1
self.input_batch.num_accepted_tokens_cpu[:num_reqs] = (
self.num_accepted_tokens.np[:num_reqs]
)
else:
# Non-async mode: use values directly
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
else:
@@ -4220,7 +4235,6 @@ def propose_draft_token_ids(sampled_token_ids):
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -4587,7 +4601,6 @@ def propose_draft_token_ids(
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -4632,7 +4645,6 @@ def propose_draft_token_ids(
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -5990,7 +6002,9 @@ def capture_model(self) -> int:
SupportsEncoderCudaGraph,
supports_encoder_cudagraph,
)
from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
from vllm.v1.worker.gpu.mm.encoder_cudagraph import (
EncoderCudaGraphManager,
)

raw_model = self.get_model()
if supports_encoder_cudagraph(raw_model):
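
For the async-scheduling branch added to _prepare_inputs, a minimal numpy sketch of the gather it performs (illustrative values only, not part of the diff; the assumption is that prev_positions maps each condensed slot to its slot in the previous batch, with a negative entry for requests that were not present before):

import numpy as np

# num_accepted_tokens recorded for the previous batch, indexed by old slot.
prev_accepted = np.array([3, 1, 2], dtype=np.int64)
# New (condensed) slot -> old slot; -1 marks a request new to the batch.
prev_idx = np.array([2, -1, 0], dtype=np.int64)
new_mask = prev_idx < 0

# Gather from the old slots, using slot 0 as a safe placeholder for new requests...
accepted = prev_accepted[np.where(new_mask, 0, prev_idx)]
# ...then overwrite the new requests with the default of one accepted token.
accepted[new_mask] = 1
assert accepted.tolist() == [2, 1, 3]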