diff --git a/tests/v1/spec_decode/test_backup_token_async_spec.py b/tests/v1/spec_decode/test_backup_token_async_spec.py
new file mode 100644
index 000000000000..9340503ea4f2
--- /dev/null
+++ b/tests/v1/spec_decode/test_backup_token_async_spec.py
@@ -0,0 +1,147 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Regression tests for the backup token fix in prepare_next_token_ids_padded.
+
+Fixes #38098: with async scheduling, seq_lens_cpu is inflated by unaccepted
+draft token placeholders, causing get_token_id() to return -1.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+import torch
+
+
+class _FakeRequest:
+    def __init__(self, prompt_tokens: list[int], output_tokens: list[int]):
+        self.num_prompt_tokens = len(prompt_tokens)
+        self._prompt = prompt_tokens
+        self._output = output_tokens
+
+    @property
+    def num_tokens(self) -> int:
+        return self.num_prompt_tokens + len(self._output)
+
+    def get_token_id(self, idx: int) -> int:
+        if idx < self.num_prompt_tokens:
+            return self._prompt[idx]
+        out_idx = idx - self.num_prompt_tokens
+        if out_idx < len(self._output):
+            return self._output[out_idx]
+        return -1  # out of range
+
+
+class _FakeInputBatch:
+    def __init__(
+        self,
+        req_ids: list[str],
+        num_tokens_no_spec: list[int],
+        vocab_size: int = 32000,
+    ):
+        self.req_ids = req_ids
+        self.num_reqs = len(req_ids)
+        self.vocab_size = vocab_size
+        self.num_tokens_no_spec = np.array(num_tokens_no_spec, dtype=np.int64)
+
+
+def _make_requests(
+    req_ids: list[str],
+    prompt_lens: list[int],
+    output_lens: list[int],
+) -> dict[str, _FakeRequest]:
+    requests = {}
+    for rid, plen, olen in zip(req_ids, prompt_lens, output_lens):
+        requests[rid] = _FakeRequest(list(range(plen)), list(range(1000, 1000 + olen)))
+    return requests
+
+
+def _backup_buggy(
+    seq_lens_cpu: torch.Tensor,
+    requests: dict[str, _FakeRequest],
+    batch: _FakeInputBatch,
+) -> list[int]:
+    """Old logic: uses seq_lens_cpu directly (may be inflated)."""
+    n = batch.num_reqs
+    return [
+        requests[batch.req_ids[i]].get_token_id(int(seq_lens_cpu[i])) for i in range(n)
+    ]
+
+
+def _backup_fixed(
+    requests: dict[str, _FakeRequest],
+    batch: _FakeInputBatch,
+) -> list[int]:
+    """New logic: uses num_tokens_no_spec - 1 (last committed token)."""
+    n = batch.num_reqs
+    idx = (batch.num_tokens_no_spec[:n] - 1).tolist()
+    return [requests[batch.req_ids[i]].get_token_id(int(idx[i])) for i in range(n)]
+
+
+class TestBackupTokenAsyncSpec:
+    def test_no_inflation_fixed_returns_last_token(self):
+        req_ids = ["r0", "r1"]
+        requests = _make_requests(req_ids, [3, 3], [2, 2])
+        batch = _FakeInputBatch(req_ids, [5, 5])
+        # idx = 5-1 = 4 → output[1] = 1001
+        assert _backup_fixed(requests, batch) == [1001, 1001]
+
+    def test_inflation_buggy_returns_placeholder(self):
+        req_ids = ["r0", "r1"]
+        requests = _make_requests(req_ids, [3, 3], [2, 2])
+        batch = _FakeInputBatch(req_ids, [5, 5])
+        # inflated by 3 spec tokens → idx 8 is out of range
+        seq_lens = torch.tensor([8, 8], dtype=torch.int64)
+        assert _backup_buggy(seq_lens, requests, batch) == [-1, -1]
+
+    def test_inflation_fixed_returns_correct_token(self):
+        req_ids = ["r0", "r1"]
+        requests = _make_requests(req_ids, [3, 3], [2, 2])
+        batch = _FakeInputBatch(req_ids, [5, 5])
+        assert _backup_fixed(requests, batch) == [1001, 1001]
+
+    def test_mixed_inflation_per_request(self):
+        req_ids = ["r0", "r1", "r2"]
+        requests = {
+            "r0": _FakeRequest([0, 1], [1000, 1001, 1002]),
+            "r1": _FakeRequest([0, 1, 2, 3], [2000]),
+            "r2": _FakeRequest([0], [3000, 3001, 3002, 3003]),
+        }
+        batch = _FakeInputBatch(req_ids, [5, 5, 5])
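+        # Per-request inflation below: +2, +4, and 0 draft placeholders; even
+        # the un-inflated entry indexes one past the last committed token.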
+        seq_lens = torch.tensor([7, 9, 5], dtype=torch.int64)
+
+        assert _backup_buggy(seq_lens, requests, batch) == [-1, -1, -1]
+        assert _backup_fixed(requests, batch) == [1002, 2000, 3003]
+
+    def test_prefill_only_request(self):
+        """No output tokens yet — backup should be the last prompt token."""
+        req_ids = ["r0"]
+        requests = {"r0": _FakeRequest([10, 20, 30], [])}
+        batch = _FakeInputBatch(req_ids, [3])
+        # idx = 3-1 = 2 → prompt[2] = 30
+        assert _backup_fixed(requests, batch) == [30]
+
+    @pytest.mark.parametrize("num_spec_tokens", [1, 2, 3, 4, 5])
+    def test_various_spec_token_counts(self, num_spec_tokens: int):
+        req_ids = ["r0"]
+        requests = {"r0": _FakeRequest([0, 1, 2], list(range(1000, 1005)))}
+        batch = _FakeInputBatch(req_ids, [8])
+        # idx = 8-1 = 7 → output[4] = 1004
+        assert _backup_fixed(requests, batch) == [1004]
+
+    def test_buggy_code_was_always_off_by_one(self):
+        """The original code used seq_len as index, which is always one past
+        the end of output_token_ids even without async inflation."""
+        req_ids = ["r0"]
+        requests = {"r0": _FakeRequest([0, 1, 2], [1000, 1001])}
+        batch = _FakeInputBatch(req_ids, [5])
+
+        # no inflation: seq_len == num_tokens == 5 → idx 5 is out of range
+        seq_lens = torch.tensor([5], dtype=torch.int64)
+        assert _backup_buggy(seq_lens, requests, batch) == [-1]
+        assert _backup_fixed(requests, batch) == [1001]
+
+        # with inflation: still -1, fixed still correct
+        seq_lens_inf = torch.tensor([8], dtype=torch.int64)
+        assert _backup_buggy(seq_lens_inf, requests, batch) == [-1]
+        assert _backup_fixed(requests, batch) == [1001]
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 496ff85f7562..c1082202448e 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -3,6 +3,7 @@
 
 from unittest import mock
 
+import numpy as np
 import pytest
 import torch
 
@@ -127,16 +128,14 @@ def test_prepare_next_token_ids():
     num_requests = 4
     num_speculative_tokens = 4
 
-    batch_spec = BatchSpec(
-        seq_lens=[num_speculative_tokens + 1] * num_requests,
-        query_lens=[num_speculative_tokens + 1] * num_requests,
-    )
-
     req_ids = [f"req_{i + 1}" for i in range(num_requests)]
     mock_input_batch = mock.MagicMock(spec=InputBatch)
     mock_input_batch.req_ids = req_ids
     mock_input_batch.num_reqs = num_requests
     mock_input_batch.vocab_size = 100
+    mock_input_batch.num_tokens_no_spec = np.array(
+        [num_speculative_tokens + 1] * num_requests
+    )
     mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
 
     mock_requests = {}
@@ -181,19 +180,12 @@ def test_prepare_next_token_ids():
 
     assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)
 
-    common_attn_metadata = create_common_attn_metadata(
-        batch_spec,
-        block_size=BLOCK_SIZE,
-        device=device,
-    )
-
     expected_valid_sampled_tokens_count = torch.tensor(
         [2, 5, 0, 0], dtype=torch.int32, device=device
     )
 
     next_token_ids_from_padded, valid_sampled_tokens_count = (
         proposer.prepare_next_token_ids_padded(
-            common_attn_metadata.seq_lens_cpu,
            sampled_token_ids_tensor,
             mock_requests,
             mock_input_batch,
diff --git a/tests/v1/spec_decode/test_extract_hidden_states.py b/tests/v1/spec_decode/test_extract_hidden_states.py
index 27b2a53c1849..9f9758b829bf 100644
--- a/tests/v1/spec_decode/test_extract_hidden_states.py
+++ b/tests/v1/spec_decode/test_extract_hidden_states.py
@@ -3,6 +3,7 @@
 
 from unittest import mock
 
+import numpy as np
 import pytest
 import torch
 
@@ -132,16 +133,12 @@ def test_prepare_next_token_ids_padded():
     device = torch.device(current_platform.device_type)
     num_requests = 4
 
-    batch_spec = BatchSpec(
-        seq_lens=[5] * num_requests,
-        query_lens=[5] * num_requests,
-    )
-
     req_ids = [f"req_{i + 1}" for i in range(num_requests)]
     mock_input_batch = mock.MagicMock(spec=InputBatch)
     mock_input_batch.req_ids = req_ids
     mock_input_batch.num_reqs = num_requests
     mock_input_batch.vocab_size = 100
+    mock_input_batch.num_tokens_no_spec = np.array([5] * num_requests)
 
     mock_requests = {}
     for req_id in req_ids:
@@ -174,12 +171,6 @@ def test_prepare_next_token_ids_padded():
 
     proposer = _create_proposer(num_speculative_tokens=1)
 
-    common_attn_metadata = create_common_attn_metadata(
-        batch_spec,
-        block_size=16,
-        device=device,
-    )
-
     # valid_sampled_tokens_count tracks if token is valid (not -1 and in vocab range)
     # It doesn't depend on whether the request is discarded
     expected_valid_sampled_tokens_count = torch.tensor(
@@ -187,7 +178,6 @@
     )
 
     next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
-        common_attn_metadata.seq_lens_cpu,
         sampled_token_ids,
         mock_requests,
         mock_input_batch,
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index a03b707dd347..a068f33b9908 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -859,7 +859,6 @@ def prepare_next_token_ids_cpu(
 
     def prepare_next_token_ids_padded(
         self,
-        seq_lens_cpu: torch.Tensor,
         sampled_token_ids: torch.Tensor,
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
@@ -874,7 +873,7 @@ def prepare_next_token_ids_padded(
         """
         # Precompute get_token_id for when there is no valid next token
         num_reqs = gpu_input_batch.num_reqs
-        seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
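+        # num_tokens_no_spec counts only committed tokens, so index
+        # num_tokens_no_spec - 1 is the last accepted token even when async
+        # scheduling has inflated seq_lens with draft-token placeholders.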
+        seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
         self.backup_next_token_ids.np[:num_reqs] = np.array(
             [
                 requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
diff --git a/vllm/v1/spec_decode/extract_hidden_states.py b/vllm/v1/spec_decode/extract_hidden_states.py
index e26fa768a324..eb559845fa77 100644
--- a/vllm/v1/spec_decode/extract_hidden_states.py
+++ b/vllm/v1/spec_decode/extract_hidden_states.py
@@ -286,7 +286,6 @@ def _build_attn_metadata_builder(
 
     def prepare_next_token_ids_padded(
         self,
-        seq_lens: torch.Tensor,
         sampled_token_ids: torch.Tensor,
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
@@ -303,7 +302,7 @@ def prepare_next_token_ids_padded(
         device = sampled_token_ids.device
 
         # Compute backup tokens for discarded / invalid requests
-        seq_lens_list = seq_lens[:num_reqs].tolist()
+        seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
         backup_tokens_gpu = torch.tensor(
             [
                 requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a950827879ad..796af571da99 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -210,7 +210,7 @@
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
     from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-    from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
+    from vllm.v1.worker.gpu.mm.encoder_cudagraph import EncoderCudaGraphManager
 
 
 logger = init_logger(__name__)
@@ -1939,9 +1939,24 @@ def _prepare_inputs(
             # _update_states_after_model_execute for hybrid models).
             if self.num_accepted_tokens_event is not None:
                 self.num_accepted_tokens_event.synchronize()
-                self.num_accepted_tokens.np[:num_reqs] = (
-                    self.input_batch.num_accepted_tokens_cpu[:num_reqs]
-                )
+                # Async mode: condense() reordered indices, use prev_positions mapping
+                if self.use_async_scheduling and prev_req_id_to_index:
+                    prev_idx = self.prev_positions.np[:num_reqs]
+                    new_mask = prev_idx < 0
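+                    # Entries with a negative prev_idx have no mapping to a
+                    # previous slot; 0 is only a safe placeholder index, and
+                    # these entries are overwritten with 1 below.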
+                    self.num_accepted_tokens.np[:num_reqs] = (
+                        self.input_batch.num_accepted_tokens_cpu[
+                            np.where(new_mask, 0, prev_idx)
+                        ]
+                    )
+                    self.num_accepted_tokens.np[:num_reqs][new_mask] = 1
+                    self.input_batch.num_accepted_tokens_cpu[:num_reqs] = (
+                        self.num_accepted_tokens.np[:num_reqs]
+                    )
+                else:
+                    # Non-async mode: use values directly
+                    self.num_accepted_tokens.np[:num_reqs] = (
+                        self.input_batch.num_accepted_tokens_cpu[:num_reqs]
+                    )
             self.num_accepted_tokens.np[num_reqs:].fill(1)
             self.num_accepted_tokens.copy_to_gpu()
         else:
@@ -4220,7 +4235,6 @@ def propose_draft_token_ids(sampled_token_ids):
             assert spec_decode_common_attn_metadata is not None
             next_token_ids, valid_sampled_tokens_count = (
                 self.drafter.prepare_next_token_ids_padded(
-                    self.optimistic_seq_lens_cpu,
                     sampled_token_ids,
                     self.requests,
                     self.input_batch,
@@ -4587,7 +4601,6 @@ def propose_draft_token_ids(
             )
             next_token_ids, valid_sampled_tokens_count = (
                 self.drafter.prepare_next_token_ids_padded(
-                    self.optimistic_seq_lens_cpu,
                     sampled_token_ids,
                     self.requests,
                     self.input_batch,
@@ -4632,7 +4645,6 @@ def propose_draft_token_ids(
             )
             next_token_ids, valid_sampled_tokens_count = (
                 self.drafter.prepare_next_token_ids_padded(
-                    self.optimistic_seq_lens_cpu,
                     sampled_token_ids,
                     self.requests,
                     self.input_batch,
@@ -5990,7 +6002,9 @@ def capture_model(self) -> int:
             SupportsEncoderCudaGraph,
             supports_encoder_cudagraph,
         )
-        from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
+        from vllm.v1.worker.gpu.mm.encoder_cudagraph import (
+            EncoderCudaGraphManager,
+        )
 
         raw_model = self.get_model()
         if supports_encoder_cudagraph(raw_model):