From eeef2379c740fe60d25333e1e4286792211e4e0d Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Thu, 26 Feb 2026 03:21:13 +0000 Subject: [PATCH 1/4] Works Signed-off-by: Andy Lo --- tests/v1/spec_decode/test_rejection_sample.py | 299 ++++++++++++++++++ vllm/v1/worker/gpu/input_batch.py | 4 + vllm/v1/worker/gpu/model_runner.py | 67 ++-- vllm/v1/worker/gpu/sample/output.py | 1 + vllm/v1/worker/gpu/sample/sampler.py | 101 +++++- .../gpu/spec_decode/eagle/speculator.py | 46 ++- vllm/v1/worker/gpu/spec_decode/outputs.py | 13 + .../gpu/spec_decode/rejection_sample.py | 166 +++++++++- vllm/v1/worker/gpu/states.py | 7 + 9 files changed, 638 insertions(+), 66 deletions(-) create mode 100644 tests/v1/spec_decode/test_rejection_sample.py create mode 100644 vllm/v1/worker/gpu/spec_decode/outputs.py diff --git a/tests/v1/spec_decode/test_rejection_sample.py b/tests/v1/spec_decode/test_rejection_sample.py new file mode 100644 index 000000000000..af7e188cddaf --- /dev/null +++ b/tests/v1/spec_decode/test_rejection_sample.py @@ -0,0 +1,299 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the sample_recovered_and_bonus_tokens Triton kernel. + +The kernel implements Gumbel-max sampling for speculative decoding: + - For each draft token position: sample from max(0, target_probs - draft_probs) + - For each bonus token (last per request): sample directly from target_probs +""" + +import pytest +import torch +import torch.nn.functional as F + +from vllm.platforms import current_platform +from vllm.v1.worker.gpu.spec_decode.rejection_sample import ( + sample_recovered_and_bonus_tokens, +) + +DEVICE = current_platform.device_type + +# The Triton kernel processes the vocab in blocks of this size. +# Vocab sizes that are not multiples of this are the most likely to trigger +# the NaN bug fixed in this file (padding positions hit -inf - (-inf) = NaN). 
+KERNEL_BLOCK_SIZE = 1024 + + +def make_inputs( + num_draft_tokens_per_req: list[int], + vocab_size: int, + target_probs: torch.Tensor | None = None, + draft_probs: torch.Tensor | None = None, + seeds: torch.Tensor | None = None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """Build the input tensors needed by sample_recovered_and_bonus_tokens. + + Each request contributes (num_draft + 1) token rows to target_probs: + num_draft recovered-token rows followed by one bonus-token row. + draft_probs has one row per draft token only (no bonus row). + """ + num_reqs = len(num_draft_tokens_per_req) + # Each request has num_draft draft positions + 1 bonus position. + num_tokens = sum(n + 1 for n in num_draft_tokens_per_req) + num_draft_tokens = sum(num_draft_tokens_per_req) + + # cu_num_logits[i+1] - cu_num_logits[i] = num_draft_tokens_per_req[i] + 1 + cu = [0] + for n in num_draft_tokens_per_req: + cu.append(cu[-1] + n + 1) + cu_num_logits = torch.tensor(cu, dtype=torch.int32, device=DEVICE) + + # idx_mapping[token_idx] = req_idx for every token belonging to that req. + idx_mapping = torch.zeros(num_tokens, dtype=torch.int32, device=DEVICE) + for req_idx, n in enumerate(num_draft_tokens_per_req): + start = cu[req_idx] + end = cu[req_idx + 1] + idx_mapping[start:end] = req_idx + + if seeds is None: + seeds = torch.arange(num_reqs, dtype=torch.int64, device=DEVICE) + + # pos is used as the RNG offset; distinct values give diverse random draws. 
+ pos = torch.arange(num_tokens, dtype=torch.int32, device=DEVICE) + + if target_probs is None: + target_probs = F.softmax( + torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE), + dim=-1, + ) + if draft_probs is None: + if num_draft_tokens > 0: + draft_probs = F.softmax( + torch.rand( + num_draft_tokens, vocab_size, dtype=torch.float32, device=DEVICE + ), + dim=-1, + ) + else: + draft_probs = torch.empty(0, vocab_size, dtype=torch.float32, device=DEVICE) + + return target_probs, draft_probs, cu_num_logits, idx_mapping, seeds, pos + + +def peaked_probs(num_rows: int, vocab_size: int, token_ids: list[int]) -> torch.Tensor: + """Return near-one-hot probability distributions. + + Row i has almost all probability mass on token_ids[i], making sampling + effectively deterministic regardless of the Gumbel noise realisation. + """ + assert len(token_ids) == num_rows + probs = torch.full( + (num_rows, vocab_size), 1e-10, dtype=torch.float32, device=DEVICE + ) + for i, t in enumerate(token_ids): + probs[i, t] = 1.0 + return probs / probs.sum(dim=-1, keepdim=True) + + +# --------------------------------------------------------------------------- +# Range / regression tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "vocab_size", + [ + 100, + 1000, + KERNEL_BLOCK_SIZE, # aligned: last block is fully valid + KERNEL_BLOCK_SIZE + 1, # one valid token in last block + KERNEL_BLOCK_SIZE * 2 - 1, # one invalid token in last block + 128256, # Llama-3 vocab; 128256 % 1024 == 256 (256 valid tokens in last block) + 128257, # Llama-3 vocab + 1; 257 valid tokens in last block + ], +) +@pytest.mark.parametrize( + "num_draft_tokens_per_req", + [ + [1], + [3], + [1, 2, 3], + ], +) +def test_output_in_vocab_range(vocab_size: int, num_draft_tokens_per_req: list[int]): + """All sampled token IDs must satisfy 0 <= id < vocab_size. 
+ + This is a regression test for a NaN bug: when the vocab size is not a + multiple of BLOCK_SIZE the out-of-bounds padding positions in the last + block were loaded as -inf for both target_probs and draft_probs. The + subtraction -inf - (-inf) = NaN, which then propagated through tl.max + and could select an out-of-vocab index. + """ + args = make_inputs(num_draft_tokens_per_req, vocab_size) + sampled = sample_recovered_and_bonus_tokens(*args) + + num_tokens = sum(n + 1 for n in num_draft_tokens_per_req) + assert sampled.shape == (num_tokens,) + assert (sampled >= 0).all(), f"Negative token IDs: {sampled}" + assert (sampled < vocab_size).all(), ( + f"Out-of-vocab token IDs (vocab_size={vocab_size}): max={sampled.max().item()}" + ) + + +# --------------------------------------------------------------------------- +# Correctness: bonus vs. recovered token distinction +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("num_draft_per_req", [0, 1, 3]) +def test_bonus_token_samples_from_target_probs(num_draft_per_req: int): + """The bonus token (last position per request) samples from target_probs. + + It must NOT be adjusted by draft_probs. We verify this by concentrating + the target mass on a known token at every bonus position and checking the + output, regardless of what draft_probs contains. + """ + vocab_size = 100 + num_reqs = 4 + num_draft_tokens_per_req = [num_draft_per_req] * num_reqs + num_tokens = sum(n + 1 for n in num_draft_tokens_per_req) + num_draft_tokens = sum(num_draft_tokens_per_req) + + cu = [0] + [ + sum(n + 1 for n in num_draft_tokens_per_req[: i + 1]) for i in range(num_reqs) + ] + + # Assign a distinct expected token to the bonus position of each request. 
+ expected_bonus = list(range(num_reqs)) # token 0, 1, 2, 3 + target_token_ids = [0] * num_tokens + for req_idx in range(num_reqs): + bonus_pos = cu[req_idx + 1] - 1 # last position of the request + target_token_ids[bonus_pos] = expected_bonus[req_idx] + + target_probs = peaked_probs(num_tokens, vocab_size, target_token_ids) + # draft_probs is random; for bonus positions it should be ignored. + draft_probs = ( + F.softmax( + torch.rand( + num_draft_tokens, vocab_size, dtype=torch.float32, device=DEVICE + ), + dim=-1, + ) + if num_draft_tokens > 0 + else torch.empty(0, vocab_size, dtype=torch.float32, device=DEVICE) + ) + + _, _, cu_num_logits, idx_mapping, seeds, pos = make_inputs( + num_draft_tokens_per_req, vocab_size + ) + sampled = sample_recovered_and_bonus_tokens( + target_probs, draft_probs, cu_num_logits, idx_mapping, seeds, pos + ) + + for req_idx in range(num_reqs): + bonus_pos = cu[req_idx + 1] - 1 + got = sampled[bonus_pos].item() + assert got == expected_bonus[req_idx], ( + f"Request {req_idx}: bonus token expected {expected_bonus[req_idx]}, " + f"got {got}" + ) + + +def test_recovered_token_samples_from_target_minus_draft(): + """Recovered tokens sample from (target_probs - draft_probs). + + When target_probs ≈ one_hot(A) and draft_probs ≈ one_hot(B) with A ≠ B, + the adjusted distribution has positive mass only at A, so the sampled + token must be A. 
+ """ + vocab_size = 100 + num_draft_tokens_per_req = [3] + num_tokens = 4 # 3 draft + 1 bonus + num_draft_tokens = 3 + + target_token = 10 + draft_token = 20 # different from target_token + + target_probs = peaked_probs(num_tokens, vocab_size, [target_token] * num_tokens) + draft_probs = peaked_probs( + num_draft_tokens, vocab_size, [draft_token] * num_draft_tokens + ) + + _, _, cu_num_logits, idx_mapping, seeds, pos = make_inputs( + num_draft_tokens_per_req, vocab_size + ) + sampled = sample_recovered_and_bonus_tokens( + target_probs, draft_probs, cu_num_logits, idx_mapping, seeds, pos + ) + + # Recovered positions (0, 1, 2): adjusted mass is on target_token. + for i in range(num_draft_tokens): + got = sampled[i].item() + assert got == target_token, ( + f"Recovered token at position {i}: expected {target_token}, got {got}" + ) + # Bonus position (3): samples directly from target_probs → also target_token. + assert sampled[3].item() == target_token + + +# --------------------------------------------------------------------------- +# Determinism +# --------------------------------------------------------------------------- + + +def test_deterministic_with_same_seeds(): + """Identical inputs and seeds always produce identical outputs.""" + vocab_size = 1000 + num_draft_tokens_per_req = [2, 3] + args = make_inputs(num_draft_tokens_per_req, vocab_size) + + out1 = sample_recovered_and_bonus_tokens(*args) + out2 = sample_recovered_and_bonus_tokens(*args) + assert torch.equal(out1, out2) + + +def test_different_seeds_produce_different_outputs(): + """Different seeds produce different outputs (with overwhelming probability). + + Over a batch of 16 tokens drawn from a uniform distribution, the + probability that two independent runs produce the same sequence is + negligible. 
+ """ + vocab_size = 1000 + num_draft_tokens_per_req = [3, 3, 3, 3] + num_reqs = len(num_draft_tokens_per_req) + num_tokens = sum(n + 1 for n in num_draft_tokens_per_req) + num_draft_tokens = sum(num_draft_tokens_per_req) + + # Shared probs; only the seeds differ. + target_probs = F.softmax( + torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE), dim=-1 + ) + draft_probs = F.softmax( + torch.rand(num_draft_tokens, vocab_size, dtype=torch.float32, device=DEVICE), + dim=-1, + ) + + seeds_a = torch.zeros(num_reqs, dtype=torch.int64, device=DEVICE) + seeds_b = torch.full((num_reqs,), 99999, dtype=torch.int64, device=DEVICE) + + _, _, cu_num_logits, idx_mapping, _, pos = make_inputs( + num_draft_tokens_per_req, vocab_size + ) + + out_a = sample_recovered_and_bonus_tokens( + target_probs, draft_probs, cu_num_logits, idx_mapping, seeds_a, pos + ) + out_b = sample_recovered_and_bonus_tokens( + target_probs, draft_probs, cu_num_logits, idx_mapping, seeds_b, pos + ) + assert not torch.equal(out_a, out_b), ( + "Expected different seeds to produce different token sequences" + ) diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index a15da926da4e..842a604b69dd 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -81,6 +81,9 @@ class InputBatch: cu_num_logits: torch.Tensor cu_num_logits_np: np.ndarray + # [num_draft_tokens] + draft_logits: torch.Tensor + # Whether any requests in batch use structured output. 
has_structured_output_reqs: bool @@ -150,6 +153,7 @@ def make_dummy( logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, + draft_logits=torch.empty(0, 1000, device=device, dtype=torch.float32), has_structured_output_reqs=False, ) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index ccab6cec8c78..f5af77149cd9 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -69,7 +69,6 @@ from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import ( set_eagle3_aux_hidden_state_layers, ) -from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker @@ -389,7 +388,7 @@ def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None: # NOTE(woosuk): During the initial memory profiling, the sampler may skip # top_k, top_p, and logprobs, using less GPU memory than what is possible # during actual execution. 
- self.sampler( + self.sampler.sample( logits, idx_mapping, idx_mapping_np, @@ -601,11 +600,23 @@ def prepare_inputs( expanded_local_pos = torch.zeros( num_reqs, dtype=torch.int32, device=self.device ) + draft_logits = torch.empty( + 0, self.vocab_size, device=self.device, dtype=torch.float32 + ) else: num_draft_tokens = np.array( [len(draft_tokens.get(req_id, ())) for req_id in req_ids], dtype=np.int32, ) + draft_logits = torch.concat( + [ + self.req_states.draft_logits[ + self.req_states.req_id_to_index[req_id], :num_draft_token + ] + for req_id, num_draft_token in zip(req_ids, num_draft_tokens) + ], + dim=0, + ) total_num_draft_tokens = int(num_draft_tokens.sum()) total_num_logits = num_reqs + total_num_draft_tokens @@ -750,6 +761,7 @@ def prepare_inputs( logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, + draft_logits=draft_logits, has_structured_output_reqs=scheduler_output.has_structured_output_requests, ) @@ -792,36 +804,36 @@ def sample( grammar_output.grammar_bitmask, ) - # Sample tokens and compute logprobs (if needed). - sampler_output = self.sampler( - logits, - input_batch.expanded_idx_mapping, - input_batch.idx_mapping_np, - input_batch.cu_num_logits_np, - sample_pos, - input_ids, - input_batch.expanded_local_pos, - ) - if input_batch.num_draft_tokens == 0: - # No draft tokens (common case). - num_sampled = torch.ones( - input_batch.num_reqs, dtype=torch.int32, device=self.device + # Sample tokens and compute logprobs (if needed). + sampler_output = self.sampler.sample( + logits, + input_batch.expanded_idx_mapping, + input_batch.idx_mapping_np, + input_batch.cu_num_logits_np, + sample_pos, + input_ids, + input_batch.expanded_local_pos, ) else: # Rejection sampling for spec decoding. 
- sampled_tokens, num_sampled = rejection_sample( - sampler_output.sampled_token_ids, - input_ids, + sampler_output = self.sampler.rejection_sample( + logits, + input_batch.draft_logits, input_batch.cu_num_logits, + input_batch.cu_num_logits_np, + input_batch.expanded_idx_mapping, + input_batch.idx_mapping_np, + sample_pos, + input_ids, + input_batch.expanded_local_pos, self.num_speculative_steps, ) - sampler_output.sampled_token_ids = sampled_tokens # Get the number of sampled and rejected tokens. # For chunked prefills, num_sampled and num_rejected are both 0. num_sampled, num_rejected = get_num_sampled_and_rejected( - num_sampled, + sampler_output.num_sampled, input_batch.seq_lens, input_batch.cu_num_logits, input_batch.idx_mapping, @@ -1090,7 +1102,7 @@ def sample_tokens( input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected ) if self.speculator is not None: - draft_tokens = self.speculator.propose( + speculation = self.speculator.propose( input_batch, hidden_states, aux_hidden_states, @@ -1101,8 +1113,15 @@ def sample_tokens( self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.seeds.gpu, ) - self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens - self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens) + self.req_states.draft_tokens[input_batch.idx_mapping] = ( + speculation.draft_tokens + ) + self.req_states.draft_logits[input_batch.idx_mapping] = ( + speculation.draft_logits + ) + self.draft_tokens_handler.set_draft_tokens( + input_batch, speculation.draft_tokens + ) if self.use_async_scheduling: return async_output diff --git a/vllm/v1/worker/gpu/sample/output.py b/vllm/v1/worker/gpu/sample/output.py index 13e8cf1d6c1e..69b26ad9f388 100644 --- a/vllm/v1/worker/gpu/sample/output.py +++ b/vllm/v1/worker/gpu/sample/output.py @@ -12,3 +12,4 @@ class SamplerOutput: sampled_token_ids: torch.Tensor logprobs_tensors: LogprobsTensors | None num_nans: torch.Tensor | None + num_sampled: torch.Tensor 
diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index 87b10bcc11a1..9d020e1414d4 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -15,6 +15,12 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.penalties import PenaltiesState from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates +from vllm.v1.worker.gpu.spec_decode.rejection_sample import ( + rejection_sample as rejection_sample_functional, +) +from vllm.v1.worker.gpu.spec_decode.rejection_sample import ( + sample_recovered_and_bonus_tokens, +) from vllm.v1.worker.gpu.states import RequestState @@ -53,7 +59,7 @@ def apply_staged_writes(self) -> None: self.logit_bias_state.apply_staged_writes() self.bad_words_state.apply_staged_writes() - def __call__( + def sample( self, logits: torch.Tensor, idx_mapping: torch.Tensor, @@ -66,7 +72,7 @@ def __call__( # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear # that num_nans is computed before applying penalties and temperature. 
num_nans = get_num_nans(logits) if self.compute_nans else None - sampled, processed_logits = self.sample( + processed_logits = self._process_logits( logits, idx_mapping, idx_mapping_np, @@ -74,6 +80,14 @@ def __call__( input_ids, expanded_local_pos, ) + sampled = gumbel_sample( + logits, + idx_mapping, + self.sampling_states.temperature.gpu, + self.sampling_states.seeds.gpu, + pos, + apply_temperature=False, + ) max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np) if max_num_logprobs != NO_LOGPROBS: @@ -81,12 +95,17 @@ def __call__( logits = processed_logits expanded_logits = logits.shape[0] != idx_mapping_np.shape[0] cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None + # TODO: Check if compute_topk_logprobs can handle 2d sampled logprobs_tensors = compute_topk_logprobs( logits, max_num_logprobs, sampled, cu_num_logits ) else: logprobs_tensors = None + # No draft tokens (common case). + num_reqs = len(logits) + num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=logits.device) + # These are GPU tensors. 
sampler_output = SamplerOutput( # The sampled tokens are expanded to 2D tensor with shape @@ -95,10 +114,73 @@ def __call__( sampled_token_ids=sampled.view(-1, 1), logprobs_tensors=logprobs_tensors, num_nans=num_nans, + num_sampled=num_sampled, ) return sampler_output - def sample( + def rejection_sample( + self, + logits: torch.Tensor, # [num_draft_tokens + num_reqs, vocab_size] + draft_logits: torch.Tensor, # [num_draft_tokens + num_reqs] + cu_num_logits: torch.Tensor, # [num_reqs + 1] + cu_num_logits_np: np.ndarray, # [num_reqs + 1] + idx_mapping: torch.Tensor, # [num_draft_tokens + num_reqs] + idx_mapping_np: np.ndarray, # [num_draft_tokens + num_reqs] + pos: torch.Tensor, # [num_draft_tokens + num_reqs] + input_ids: torch.Tensor, # [num_draft_tokens + num_reqs] + expanded_local_pos: torch.Tensor, # [num_draft_tokens + num_reqs] + num_speculative_steps: int, + ) -> SamplerOutput: + # TODO: Check whether functions expect expanded idx_mapping or not + processed_logits = self._process_logits( + logits, + idx_mapping, + idx_mapping_np, + pos, + input_ids, + expanded_local_pos, + ) + processed_probs = torch.softmax(processed_logits, dim=-1) + draft_probs = torch.softmax(draft_logits, dim=-1) + recovered_ids = sample_recovered_and_bonus_tokens( + processed_probs, + draft_probs, + cu_num_logits, + idx_mapping, + self.sampling_states.seeds.gpu, + pos, + ) + sampled, num_sampled = rejection_sample_functional( + input_ids, + recovered_ids, + processed_probs, + draft_probs, + cu_num_logits, + self.sampling_states.seeds.gpu, + pos, + idx_mapping, + num_speculative_steps, + ) + max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np) + if max_num_logprobs != NO_LOGPROBS: + expanded_logits = logits.shape[0] != idx_mapping_np.shape[0] + cu_num_logits_list = cu_num_logits_np.tolist() if expanded_logits else None + logprobs_tensors = compute_topk_logprobs( + processed_logits, max_num_logprobs, sampled, cu_num_logits_list + ) + else: + logprobs_tensors = None + + 
# These are GPU tensors. + sampler_output = SamplerOutput( + sampled_token_ids=sampled, + logprobs_tensors=logprobs_tensors, + num_nans=None, + num_sampled=num_sampled, + ) + return sampler_output + + def _process_logits( self, logits: torch.Tensor, idx_mapping: torch.Tensor, @@ -106,7 +188,7 @@ def sample( pos: torch.Tensor, input_ids: torch.Tensor, expanded_local_pos: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: # Copy logits to a new FP32 tensor. logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits) @@ -143,13 +225,4 @@ def sample( logits, idx_mapping, idx_mapping_np ) - # Sample the next token. - sampled = gumbel_sample( - logits, - idx_mapping, - self.sampling_states.temperature.gpu, - self.sampling_states.seeds.gpu, - pos, - apply_temperature=False, - ) - return sampled, logits + return logits diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 6cd13cebf995..0f7cb116c357 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -17,9 +17,10 @@ ) from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers -from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample +from vllm.v1.worker.gpu.sample.gumbel import apply_temperature, gumbel_sample from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model +from vllm.v1.worker.gpu.spec_decode.outputs import Speculation from vllm.v1.worker.utils import AttentionGroup logger = init_logger(__name__) @@ -69,6 +70,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): dtype=torch.int64, device=device, ) + self.draft_logits = torch.zeros( + self.max_num_reqs, + self.num_speculative_steps, + self.vocab_size, + dtype=torch.float32, + device=device, + ) 
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device) @@ -139,19 +147,21 @@ def generate_draft( ) last_hidden_states = last_hidden_states[:num_reqs] hidden_states = hidden_states[:num_reqs] - logits = self.model.compute_logits(last_hidden_states) + draft_logits = self.model.compute_logits(last_hidden_states) + apply_temperature(draft_logits, idx_mapping, self.temperature) # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise # used for draft and target sampling. draft_tokens = gumbel_sample( - logits, + draft_logits, idx_mapping, self.temperature, self.seeds, pos + 1, - apply_temperature=True, + apply_temperature=False, ) self.draft_tokens[:num_reqs, step] = draft_tokens + self.draft_logits[:num_reqs, step] = draft_logits if step < self.num_speculative_steps - 1: # Update the inputs for the next step. @@ -198,7 +208,7 @@ def propose( temperature: torch.Tensor, # [max_num_reqs] seeds: torch.Tensor, - ) -> torch.Tensor: + ) -> Speculation: # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the # number of rejected tokens, we maintain the size of eagle's input_ids and # hidden_states the same as the target model's. This means, we pad each @@ -234,7 +244,7 @@ def propose( num_tokens_across_dp=None, # FIXME ) sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states) + draft_logits = self.model.compute_logits(sample_hidden_states) num_reqs = input_batch.num_reqs # NOTE(woosuk): For draft sampling, we only consider the temperature @@ -251,20 +261,26 @@ def propose( torch.gather(input_batch.positions, 0, last_token_indices, out=pos) # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise # used for draft and target sampling. 
+ apply_temperature(draft_logits, idx_mapping, self.temperature) draft_tokens = gumbel_sample( - logits, + draft_logits, idx_mapping, self.temperature, self.seeds, pos + 1, - apply_temperature=True, + apply_temperature=False, ) if self.num_speculative_steps == 1: # Early exit. - return draft_tokens.view(-1, 1) + return Speculation( + draft_tokens=draft_tokens.view(-1, 1), + draft_logits=draft_logits.view(-1, 1), + ) - # Save the draft tokens for the first step. + # Save the draft tokens/probs for the first step. self.draft_tokens[:num_reqs, 0] = draft_tokens + self.draft_logits[:num_reqs, 0] = draft_logits + # Prepare the inputs for the decode steps. prepare_eagle_decode( draft_tokens, @@ -287,7 +303,10 @@ def propose( if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL: # Run full CUDA graph. self.cudagraph_manager.run_fullgraph(cudagraph_size) - return self.draft_tokens[:num_reqs] + return Speculation( + draft_tokens=self.draft_tokens[:num_reqs], + draft_logits=self.draft_logits[:num_reqs], + ) # Run eager or piecewise CUDA graph. 
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs @@ -321,7 +340,10 @@ def propose( num_tokens_across_dp=None, # FIXME cudagraph_runtime_mode=cudagraph_mode, ) - return self.draft_tokens[:num_reqs] + return Speculation( + draft_tokens=self.draft_tokens[:num_reqs], + draft_logits=self.draft_logits[:num_reqs], + ) @triton.jit diff --git a/vllm/v1/worker/gpu/spec_decode/outputs.py b/vllm/v1/worker/gpu/spec_decode/outputs.py new file mode 100644 index 000000000000..07159b59e2b5 --- /dev/null +++ b/vllm/v1/worker/gpu/spec_decode/outputs.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import torch + + +@dataclass +class Speculation: + # [num_reqs, num_speculative_steps] + draft_tokens: torch.Tensor + # [num_reqs, num_speculative_steps, vocab_size] + draft_logits: torch.Tensor diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py index 8a7bf28bacbd..827473eb9273 100644 --- a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py +++ b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py @@ -10,62 +10,196 @@ def _rejection_sample_kernel( sampled_ptr, # [num_reqs, num_speculative_steps + 1] sampled_stride, num_sampled_ptr, # [num_reqs] - target_sampled_ptr, # [num_draft_tokens + num_reqs] input_ids_ptr, # [num_draft_tokens + num_reqs] cu_num_logits_ptr, # [num_reqs + 1] + target_probs_ptr, # [num_draft_tokens + num_reqs, vocab_size] + target_probs_stride, + draft_probs_ptr, # [num_draft_tokens, vocab_size] + draft_probs_stride, + recovered_ids_ptr, # [num_draft_tokens + num_reqs] + seeds_ptr, # [num_reqs] + pos_ptr, # [num_reqs] + idx_mapping_ptr, # [max_num_reqs] ): req_idx = tl.program_id(0) start_idx = tl.load(cu_num_logits_ptr + req_idx) end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) + req_state_idx = tl.load(idx_mapping_ptr + req_idx) + pos = tl.load(pos_ptr + 
req_idx) + seed = tl.load(seeds_ptr + req_state_idx) num_tokens = end_idx - start_idx num_sampled = 0 rejected = False for i in range(num_tokens - 1): if not rejected: - target_sampled = tl.load(target_sampled_ptr + start_idx + i) - draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1) - tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled) - num_sampled += 1 - if target_sampled != draft_sampled: + draft_id = tl.load(input_ids_ptr + start_idx + i + 1) + target_prob = tl.load( + target_probs_ptr + (start_idx + i) * target_probs_stride + draft_id + ) + draft_prob = tl.load( + draft_probs_ptr + + (start_idx + i - req_idx) * draft_probs_stride + + draft_id + ) + u = tl.rand(seed=seed, offset=pos + i) + if target_prob >= u * draft_prob: # Accept + token_id = draft_id + else: # Reject + token_id = tl.load(recovered_ids_ptr + start_idx + i).to(tl.int32) rejected = True + tl.store(sampled_ptr + req_idx * sampled_stride + i, token_id) + num_sampled += 1 if not rejected: - target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1) - tl.store( - sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled - ) + bonus_id = tl.load(recovered_ids_ptr + start_idx + num_tokens - 1) + tl.store(sampled_ptr + req_idx * sampled_stride + num_tokens - 1, bonus_id) num_sampled += 1 tl.store(num_sampled_ptr + req_idx, num_sampled) def rejection_sample( - # [num_draft_tokens + num_reqs] - target_sampled: torch.Tensor, # [num_draft_tokens + num_reqs] input_ids: torch.Tensor, + # [num_draft_tokens + num_reqs] + recovered_ids: torch.Tensor, + # [num_draft_tokens + num_reqs, vocab_size] + target_probs: torch.Tensor, + # [num_draft_tokens, vocab_size], + draft_probs: torch.Tensor, # [num_reqs + 1] cu_num_logits: torch.Tensor, + # [num_reqs] + seeds: torch.Tensor, + # [num_reqs] + pos: torch.Tensor, + # [max_num_reqs] + idx_mapping: torch.Tensor, num_speculative_steps: int, ) -> tuple[torch.Tensor, torch.Tensor]: num_reqs = cu_num_logits.shape[0] - 1 
sampled = torch.empty( num_reqs, num_speculative_steps + 1, - dtype=target_sampled.dtype, - device=target_sampled.device, + dtype=input_ids.dtype, + device=input_ids.device, ) num_sampled = torch.empty( num_reqs, dtype=torch.int32, - device=target_sampled.device, + device=input_ids.device, ) _rejection_sample_kernel[(num_reqs,)]( sampled, sampled.stride(0), num_sampled, - target_sampled, input_ids, cu_num_logits, + target_probs, + target_probs.stride(0), + draft_probs, + draft_probs.stride(0), + recovered_ids, + seeds, + pos, + idx_mapping, num_warps=1, ) return sampled, num_sampled + + +@triton.jit +def _sample_recovered_and_bonus_tokens_kernel( + local_argmax_ptr, + local_argmax_stride, + local_max_ptr, + local_max_stride, + target_probs_ptr, + target_probs_stride, + draft_probs_ptr, + draft_probs_stride, + cu_num_logits_ptr, + seeds_ptr, + pos_ptr, + idx_mapping_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + token_idx) + end_idx = tl.load(cu_num_logits_ptr + req_state_idx + 1) + is_bonus_token_idx = token_idx == end_idx - 1 + + block_idx = tl.program_id(1) + block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + target_probs = tl.load( + target_probs_ptr + token_idx * target_probs_stride + block, + mask=mask, + other=0.0, + ) + if not is_bonus_token_idx: + draft_probs = tl.load( + draft_probs_ptr + (token_idx - req_state_idx) * draft_probs_stride + block, + mask=mask, + other=0.0, + ) + target_probs -= draft_probs + + # Calculate the seed for exponential noise. + seed = tl.load(seeds_ptr + req_state_idx) + pos = tl.load(pos_ptr + token_idx) + gumbel_seed = tl.randint(seed, pos) + # Generate exponential noise in FP32. 
+ u = tl.rand(gumbel_seed, block) + u = tl.maximum(u, 1e-7) + exp_noise = -tl.log(u) + + value, idx = tl.max(target_probs / exp_noise, axis=0, return_indices=True) + token_id = block_idx * BLOCK_SIZE + idx + tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id) + tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value) + + +def sample_recovered_and_bonus_tokens( + target_probs: torch.Tensor, # [num_draft_tokens + num_reqs, vocab_size] + draft_probs: torch.Tensor, # [num_draft_tokens, vocab_size] + cu_num_logits: torch.Tensor, # [num_reqs + 1] + idx_mapping: torch.Tensor, # [num_draft_tokens + num_reqs] + seed: torch.Tensor, # [max_num_reqs] + pos: torch.Tensor, # [num_reqs] +) -> torch.Tensor: + num_tokens, vocab_size = target_probs.shape + BLOCK_SIZE = 1024 + num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) + local_argmax = torch.empty( + num_tokens, + num_blocks, + dtype=torch.int64, + device=target_probs.device, + ) + local_max = torch.empty( + num_tokens, + num_blocks, + dtype=torch.float32, + device=target_probs.device, + ) + _sample_recovered_and_bonus_tokens_kernel[(num_tokens, num_blocks)]( + local_argmax, + local_argmax.stride(0), + local_max, + local_max.stride(0), + target_probs, + target_probs.stride(0), + draft_probs, + draft_probs.stride(0), + cu_num_logits, + seed, + pos, + idx_mapping, + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + # NOTE(woosuk): Use int64 for later indexing. 
+ max_block_idx = local_max.argmax(dim=-1, keepdim=True) + sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1) + return sampled diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index b338d32a3e39..cb9e1ad6e0c5 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -70,6 +70,13 @@ def __init__( dtype=torch.int64, device=device, ) + self.draft_logits = torch.zeros( + self.max_num_reqs, + self.num_speculative_steps, + self.vocab_size, + dtype=torch.float32, + device=device, + ) self.next_prefill_tokens = torch.zeros( self.max_num_reqs, dtype=torch.int32, device=device ) From 75b28cbd22fca16ec0ae7b094efba7c108e5b5f2 Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Thu, 26 Feb 2026 12:23:00 +0000 Subject: [PATCH 2/4] test Signed-off-by: Andy Lo --- tests/v1/spec_decode/test_rejection_sample.py | 378 +++++------------- vllm/v1/worker/gpu/input_batch.py | 3 +- vllm/v1/worker/gpu/model_runner.py | 1 + 3 files changed, 105 insertions(+), 277 deletions(-) diff --git a/tests/v1/spec_decode/test_rejection_sample.py b/tests/v1/spec_decode/test_rejection_sample.py index af7e188cddaf..f1cc402c30fa 100644 --- a/tests/v1/spec_decode/test_rejection_sample.py +++ b/tests/v1/spec_decode/test_rejection_sample.py @@ -1,299 +1,125 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for the sample_recovered_and_bonus_tokens Triton kernel. 
- -The kernel implements Gumbel-max sampling for speculative decoding: - - For each draft token position: sample from max(0, target_probs - draft_probs) - - For each bonus token (last per request): sample directly from target_probs -""" - import pytest import torch -import torch.nn.functional as F +from scipy import stats from vllm.platforms import current_platform from vllm.v1.worker.gpu.spec_decode.rejection_sample import ( sample_recovered_and_bonus_tokens, ) -DEVICE = current_platform.device_type - -# The Triton kernel processes the vocab in blocks of this size. -# Vocab sizes that are not multiples of this are the most likely to trigger -# the NaN bug fixed in this file (padding positions hit -inf - (-inf) = NaN). -KERNEL_BLOCK_SIZE = 1024 +device = current_platform.device_type -def make_inputs( - num_draft_tokens_per_req: list[int], - vocab_size: int, - target_probs: torch.Tensor | None = None, - draft_probs: torch.Tensor | None = None, - seeds: torch.Tensor | None = None, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - """Build the input tensors needed by sample_recovered_and_bonus_tokens. +@pytest.mark.skipif(device != "cuda", reason="Requires CUDA") +def test_sample_recovered_and_bonus_tokens_correctness(): + """Verify that sample_recovered_and_bonus_tokens produces samples whose + empirical distribution matches the theoretical distribution. - Each request contributes (num_draft + 1) token rows to target_probs: - num_draft recovered-token rows followed by one bonus-token row. - draft_probs has one row per draft token only (no bonus row). + For each non-bonus (recovered) token position the correct distribution is + max(0, target - draft) renormalized; for the bonus token position it is + target directly. We draw N samples with independent seeds and run a + chi-squared goodness-of-fit test against those expected distributions. 
""" - num_reqs = len(num_draft_tokens_per_req) - # Each request has num_draft draft positions + 1 bonus position. - num_tokens = sum(n + 1 for n in num_draft_tokens_per_req) - num_draft_tokens = sum(num_draft_tokens_per_req) - - # cu_num_logits[i+1] - cu_num_logits[i] = num_draft_tokens_per_req[i] + 1 - cu = [0] - for n in num_draft_tokens_per_req: - cu.append(cu[-1] + n + 1) - cu_num_logits = torch.tensor(cu, dtype=torch.int32, device=DEVICE) - - # idx_mapping[token_idx] = req_idx for every token belonging to that req. - idx_mapping = torch.zeros(num_tokens, dtype=torch.int32, device=DEVICE) - for req_idx, n in enumerate(num_draft_tokens_per_req): - start = cu[req_idx] - end = cu[req_idx + 1] - idx_mapping[start:end] = req_idx - - if seeds is None: - seeds = torch.arange(num_reqs, dtype=torch.int64, device=DEVICE) - - # pos is used as the RNG offset; distinct values give diverse random draws. - pos = torch.arange(num_tokens, dtype=torch.int32, device=DEVICE) - - if target_probs is None: - target_probs = F.softmax( - torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE), - dim=-1, - ) - if draft_probs is None: - if num_draft_tokens > 0: - draft_probs = F.softmax( - torch.rand( - num_draft_tokens, vocab_size, dtype=torch.float32, device=DEVICE - ), - dim=-1, - ) - else: - draft_probs = torch.empty(0, vocab_size, dtype=torch.float32, device=DEVICE) - - return target_probs, draft_probs, cu_num_logits, idx_mapping, seeds, pos - - -def peaked_probs(num_rows: int, vocab_size: int, token_ids: list[int]) -> torch.Tensor: - """Return near-one-hot probability distributions. - - Row i has almost all probability mass on token_ids[i], making sampling - effectively deterministic regardless of the Gumbel noise realisation. 
- """ - assert len(token_ids) == num_rows - probs = torch.full( - (num_rows, vocab_size), 1e-10, dtype=torch.float32, device=DEVICE + # 3 requests with 2 / 1 / 2 draft tokens each, giving 8 total token + # req 0 (2 draft + 1 bonus): tok 0, 1, 2 + # req 1 (1 draft + 1 bonus): tok 3, 4 + # req 2 (2 draft + 1 bonus): tok 5, 6, 7 + num_reqs = 3 + num_tokens = 8 + vocab_size = 3 + + target_probs = torch.tensor( + [ + [0.6, 0.3, 0.1], # req 0 recovered + [0.3, 0.4, 0.3], # req 0 recovered + [0.1, 0.7, 0.2], # req 0 bonus + [0.5, 0.3, 0.2], # req 1 recovered + [0.2, 0.5, 0.3], # req 1 bonus + [0.4, 0.4, 0.2], # req 2 recovered + [0.2, 0.6, 0.2], # req 2 recovered + [0.3, 0.3, 0.4], # req 2 bonus + ], + dtype=torch.float32, + device=device, ) - for i, t in enumerate(token_ids): - probs[i, t] = 1.0 - return probs / probs.sum(dim=-1, keepdim=True) - - -# --------------------------------------------------------------------------- -# Range / regression tests -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize( - "vocab_size", - [ - 100, - 1000, - KERNEL_BLOCK_SIZE, # aligned: last block is fully valid - KERNEL_BLOCK_SIZE + 1, # one valid token in last block - KERNEL_BLOCK_SIZE * 2 - 1, # one invalid token in last block - 128256, # Llama-3 vocab; 128256 % 1024 == 256 (256 valid tokens in last block) - 128257, # Llama-3 vocab + 1; 257 valid tokens in last block - ], -) -@pytest.mark.parametrize( - "num_draft_tokens_per_req", - [ - [1], - [3], - [1, 2, 3], - ], -) -def test_output_in_vocab_range(vocab_size: int, num_draft_tokens_per_req: list[int]): - """All sampled token IDs must satisfy 0 <= id < vocab_size. - - This is a regression test for a NaN bug: when the vocab size is not a - multiple of BLOCK_SIZE the out-of-bounds padding positions in the last - block were loaded as -inf for both target_probs and draft_probs. 
The - subtraction -inf - (-inf) = NaN, which then propagated through tl.max - and could select an out-of-vocab index. - """ - args = make_inputs(num_draft_tokens_per_req, vocab_size) - sampled = sample_recovered_and_bonus_tokens(*args) - - num_tokens = sum(n + 1 for n in num_draft_tokens_per_req) - assert sampled.shape == (num_tokens,) - assert (sampled >= 0).all(), f"Negative token IDs: {sampled}" - assert (sampled < vocab_size).all(), ( - f"Out-of-vocab token IDs (vocab_size={vocab_size}): max={sampled.max().item()}" + draft_probs = torch.tensor( + [ + [0.1, 0.2, 0.7], # req 0 draft 0 + [0.1, 0.2, 0.7], # req 0 draft 1 + [0.3, 0.1, 0.6], # req 1 draft 0 + [0.1, 0.3, 0.6], # req 2 draft 0 + [0.1, 0.2, 0.7], # req 2 draft 1 + ], + dtype=torch.float32, + device=device, ) - -# --------------------------------------------------------------------------- -# Correctness: bonus vs. recovered token distinction -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize("num_draft_per_req", [0, 1, 3]) -def test_bonus_token_samples_from_target_probs(num_draft_per_req: int): - """The bonus token (last position per request) samples from target_probs. - - It must NOT be adjusted by draft_probs. We verify this by concentrating - the target mass on a known token at every bonus position and checking the - output, regardless of what draft_probs contains. - """ - vocab_size = 100 - num_reqs = 4 - num_draft_tokens_per_req = [num_draft_per_req] * num_reqs - num_tokens = sum(n + 1 for n in num_draft_tokens_per_req) - num_draft_tokens = sum(num_draft_tokens_per_req) - - cu = [0] + [ - sum(n + 1 for n in num_draft_tokens_per_req[: i + 1]) for i in range(num_reqs) - ] - - # Assign a distinct expected token to the bonus position of each request. 
- expected_bonus = list(range(num_reqs)) # token 0, 1, 2, 3 - target_token_ids = [0] * num_tokens - for req_idx in range(num_reqs): - bonus_pos = cu[req_idx + 1] - 1 # last position of the request - target_token_ids[bonus_pos] = expected_bonus[req_idx] - - target_probs = peaked_probs(num_tokens, vocab_size, target_token_ids) - # draft_probs is random; for bonus positions it should be ignored. - draft_probs = ( - F.softmax( - torch.rand( - num_draft_tokens, vocab_size, dtype=torch.float32, device=DEVICE - ), - dim=-1, - ) - if num_draft_tokens > 0 - else torch.empty(0, vocab_size, dtype=torch.float32, device=DEVICE) + # Expected distributions (recovered = max(0, target-draft) normalised): + # tok 0: max(0, [0.5, 0.1, -0.6]) → [5/6, 1/6, 0] + # tok 1: max(0, [0.2, 0.2, -0.4]) → [1/2, 1/2, 0] + # tok 2: bonus → [0.1, 0.7, 0.2] + # tok 3: max(0, [0.2, 0.2, -0.4]) → [1/2, 1/2, 0] + # tok 4: bonus → [0.2, 0.5, 0.3] + # tok 5: max(0, [0.3, 0.1, -0.4]) → [3/4, 1/4, 0] + # tok 6: max(0, [0.1, 0.4, -0.5]) → [1/5, 4/5, 0] + # tok 7: bonus → [0.3, 0.3, 0.4] + expected_dists = torch.tensor( + [ + [5 / 6, 1 / 6, 0.0], # req 0 recovered + [0.5, 0.5, 0.0], # req 0 recovered + [0.1, 0.7, 0.2], # req 0 bonus + [0.5, 0.5, 0.0], # req 1 recovered + [0.2, 0.5, 0.3], # req 1 bonus + [3 / 4, 1 / 4, 0.0], # req 2 recovered + [1 / 5, 4 / 5, 0.0], # req 2 recovered + [0.3, 0.3, 0.4], # req 2 bonus + ], + dtype=torch.float32, + device=device, ) - _, _, cu_num_logits, idx_mapping, seeds, pos = make_inputs( - num_draft_tokens_per_req, vocab_size - ) - sampled = sample_recovered_and_bonus_tokens( - target_probs, draft_probs, cu_num_logits, idx_mapping, seeds, pos + cu_num_logits = torch.tensor([0, 3, 5, 8], dtype=torch.int32, device=device) + idx_mapping = torch.tensor( + [0, 0, 0, 1, 1, 2, 2, 2], dtype=torch.int32, device=device ) - - for req_idx in range(num_reqs): - bonus_pos = cu[req_idx + 1] - 1 - got = sampled[bonus_pos].item() - assert got == expected_bonus[req_idx], ( - f"Request 
{req_idx}: bonus token expected {expected_bonus[req_idx]}, " - f"got {got}" + # Exact values don't matter, only used for seeding + pos = torch.arange(num_tokens, dtype=torch.int64, device=device) + + # Draw N samples with varying seeds to build empirical distributions + N = 30_000 + counts = torch.zeros(num_tokens, vocab_size, dtype=torch.int64, device=device) + for trial in range(N): + # Give each request a distinct seed that changes every trial. + seed = torch.arange( + trial * num_reqs + 1, + (trial + 1) * num_reqs + 1, + dtype=torch.int64, + device=device, ) - - -def test_recovered_token_samples_from_target_minus_draft(): - """Recovered tokens sample from (target_probs - draft_probs). - - When target_probs ≈ one_hot(A) and draft_probs ≈ one_hot(B) with A ≠ B, - the adjusted distribution has positive mass only at A, so the sampled - token must be A. - """ - vocab_size = 100 - num_draft_tokens_per_req = [3] - num_tokens = 4 # 3 draft + 1 bonus - num_draft_tokens = 3 - - target_token = 10 - draft_token = 20 # different from target_token - - target_probs = peaked_probs(num_tokens, vocab_size, [target_token] * num_tokens) - draft_probs = peaked_probs( - num_draft_tokens, vocab_size, [draft_token] * num_draft_tokens - ) - - _, _, cu_num_logits, idx_mapping, seeds, pos = make_inputs( - num_draft_tokens_per_req, vocab_size - ) - sampled = sample_recovered_and_bonus_tokens( - target_probs, draft_probs, cu_num_logits, idx_mapping, seeds, pos - ) - - # Recovered positions (0, 1, 2): adjusted mass is on target_token. 
- for i in range(num_draft_tokens): - got = sampled[i].item() - assert got == target_token, ( - f"Recovered token at position {i}: expected {target_token}, got {got}" + samples = sample_recovered_and_bonus_tokens( + target_probs, draft_probs, cu_num_logits, idx_mapping, seed, pos + ) + counts.scatter_add_( + 1, + samples.unsqueeze(1), + torch.ones(num_tokens, 1, dtype=torch.int64, device=device), ) - # Bonus position (3): samples directly from target_probs → also target_token. - assert sampled[3].item() == target_token - - -# --------------------------------------------------------------------------- -# Determinism -# --------------------------------------------------------------------------- - - -def test_deterministic_with_same_seeds(): - """Identical inputs and seeds always produce identical outputs.""" - vocab_size = 1000 - num_draft_tokens_per_req = [2, 3] - args = make_inputs(num_draft_tokens_per_req, vocab_size) - - out1 = sample_recovered_and_bonus_tokens(*args) - out2 = sample_recovered_and_bonus_tokens(*args) - assert torch.equal(out1, out2) - - -def test_different_seeds_produce_different_outputs(): - """Different seeds produce different outputs (with overwhelming probability). - - Over a batch of 16 tokens drawn from a uniform distribution, the - probability that two independent runs produce the same sequence is - negligible. - """ - vocab_size = 1000 - num_draft_tokens_per_req = [3, 3, 3, 3] - num_reqs = len(num_draft_tokens_per_req) - num_tokens = sum(n + 1 for n in num_draft_tokens_per_req) - num_draft_tokens = sum(num_draft_tokens_per_req) - - # Shared probs; only the seeds differ. 
- target_probs = F.softmax( - torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE), dim=-1 - ) - draft_probs = F.softmax( - torch.rand(num_draft_tokens, vocab_size, dtype=torch.float32, device=DEVICE), - dim=-1, - ) - - seeds_a = torch.zeros(num_reqs, dtype=torch.int64, device=DEVICE) - seeds_b = torch.full((num_reqs,), 99999, dtype=torch.int64, device=DEVICE) - - _, _, cu_num_logits, idx_mapping, _, pos = make_inputs( - num_draft_tokens_per_req, vocab_size - ) - out_a = sample_recovered_and_bonus_tokens( - target_probs, draft_probs, cu_num_logits, idx_mapping, seeds_a, pos - ) - out_b = sample_recovered_and_bonus_tokens( - target_probs, draft_probs, cu_num_logits, idx_mapping, seeds_b, pos - ) - assert not torch.equal(out_a, out_b), ( - "Expected different seeds to produce different token sequences" - ) + # Chi-squared test to compare empirical and expected distributions + # Chi-squared test cannot handle zero frequencies, + # so we check those explicitly first + expected_freq = expected_dists * N # [num_tokens, vocab_size] + nonzero = expected_freq > 0 # [num_tokens, vocab_size] + assert torch.where(~nonzero, counts, 0).sum() == 0, "Sampled a zero-prob token" + + alpha = 1e-4 + for tok_pos in range(num_tokens): + obs = counts[tok_pos, nonzero[tok_pos]].cpu().numpy().astype(float) + exp = expected_freq[tok_pos, nonzero[tok_pos]].cpu().numpy().astype(float) + _, p_value = stats.chisquare(obs, f_exp=exp, sum_check=False) + assert p_value > alpha, ( + f"Token position {tok_pos}: empirical distribution significantly " + f"differs from expected ({p_value=:.6f})" + ) diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 842a604b69dd..055c108accbd 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -92,6 +92,7 @@ def make_dummy( cls, num_reqs: int, num_tokens: int, + vocab_size: int, input_buffers: InputBuffers, device: torch.device, ) -> "InputBatch": @@ -153,7 +154,7 @@ def 
make_dummy( logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, - draft_logits=torch.empty(0, 1000, device=device, dtype=torch.float32), + draft_logits=torch.empty(0, vocab_size, device=device, dtype=torch.float32), has_structured_output_reqs=False, ) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index f5af77149cd9..5cfe40c0a9d4 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -947,6 +947,7 @@ def execute_model( input_batch = InputBatch.make_dummy( num_reqs=num_reqs, num_tokens=num_tokens_after_padding, + vocab_size=self.vocab_size, input_buffers=self.input_buffers, device=self.device, ) From ae0f2fec17bb41d44c3f1acaeabda7a1d9006783 Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Fri, 27 Feb 2026 00:38:23 +0000 Subject: [PATCH 3/4] Fix Signed-off-by: Andy Lo --- vllm/v1/worker/gpu/model_runner.py | 1 + vllm/v1/worker/gpu/sample/sampler.py | 10 +- .../gpu/spec_decode/eagle/speculator.py | 2 +- .../gpu/spec_decode/rejection_sample.py | 104 +++++++++++------- 4 files changed, 73 insertions(+), 44 deletions(-) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 5cfe40c0a9d4..e6f723678e8a 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -822,6 +822,7 @@ def sample( input_batch.draft_logits, input_batch.cu_num_logits, input_batch.cu_num_logits_np, + input_batch.idx_mapping, input_batch.expanded_idx_mapping, input_batch.idx_mapping_np, sample_pos, diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index 9d020e1414d4..1a31e86e645a 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -124,7 +124,8 @@ def rejection_sample( draft_logits: torch.Tensor, # [num_draft_tokens + num_reqs] cu_num_logits: torch.Tensor, # [num_reqs + 1] cu_num_logits_np: np.ndarray, # [num_reqs + 1] - idx_mapping: 
torch.Tensor, # [num_draft_tokens + num_reqs] + idx_mapping: torch.Tensor, # [max_num_reqs] + expanded_idx_mapping: torch.Tensor, # [num_draft_tokens + num_reqs] idx_mapping_np: np.ndarray, # [num_draft_tokens + num_reqs] pos: torch.Tensor, # [num_draft_tokens + num_reqs] input_ids: torch.Tensor, # [num_draft_tokens + num_reqs] @@ -134,7 +135,7 @@ def rejection_sample( # TODO: Check whether functions expect expanded idx_mapping or not processed_logits = self._process_logits( logits, - idx_mapping, + expanded_idx_mapping, idx_mapping_np, pos, input_ids, @@ -146,7 +147,8 @@ def rejection_sample( processed_probs, draft_probs, cu_num_logits, - idx_mapping, + expanded_idx_mapping, + self.sampling_states.temperature.gpu, self.sampling_states.seeds.gpu, pos, ) @@ -158,7 +160,7 @@ def rejection_sample( cu_num_logits, self.sampling_states.seeds.gpu, pos, - idx_mapping, + self.sampling_states.temperature.gpu, num_speculative_steps, ) max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np) diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 0f7cb116c357..27da40174871 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -274,7 +274,7 @@ def propose( # Early exit. return Speculation( draft_tokens=draft_tokens.view(-1, 1), - draft_logits=draft_logits.view(-1, 1), + draft_logits=draft_logits.view(-1, 1, self.vocab_size), ) # Save the draft tokens/probs for the first step. 
diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py index 827473eb9273..bfde4797c8c3 100644 --- a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py +++ b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py @@ -16,38 +16,44 @@ def _rejection_sample_kernel( target_probs_stride, draft_probs_ptr, # [num_draft_tokens, vocab_size] draft_probs_stride, - recovered_ids_ptr, # [num_draft_tokens + num_reqs] + recovered_ids_ptr, # [num_draft_tokens + num_draft_tokens + num_reqs] seeds_ptr, # [num_reqs] - pos_ptr, # [num_reqs] - idx_mapping_ptr, # [max_num_reqs] + pos_ptr, # [num_draft_tokens + num_reqs] + temperature_ptr, # [max_num_reqs] ): req_idx = tl.program_id(0) start_idx = tl.load(cu_num_logits_ptr + req_idx) end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) - req_state_idx = tl.load(idx_mapping_ptr + req_idx) - pos = tl.load(pos_ptr + req_idx) - seed = tl.load(seeds_ptr + req_state_idx) + seed = tl.load(seeds_ptr + req_idx) + temperature = tl.load(temperature_ptr + req_idx) num_tokens = end_idx - start_idx + is_zero_temp = temperature == 0.0 num_sampled = 0 rejected = False for i in range(num_tokens - 1): if not rejected: draft_id = tl.load(input_ids_ptr + start_idx + i + 1) - target_prob = tl.load( - target_probs_ptr + (start_idx + i) * target_probs_stride + draft_id - ) - draft_prob = tl.load( - draft_probs_ptr - + (start_idx + i - req_idx) * draft_probs_stride - + draft_id - ) - u = tl.rand(seed=seed, offset=pos + i) - if target_prob >= u * draft_prob: # Accept - token_id = draft_id - else: # Reject + if is_zero_temp: token_id = tl.load(recovered_ids_ptr + start_idx + i).to(tl.int32) - rejected = True + if token_id != draft_id: + rejected = True + else: + target_prob = tl.load( + target_probs_ptr + (start_idx + i) * target_probs_stride + draft_id + ) + draft_prob = tl.load( + draft_probs_ptr + + (start_idx + i - req_idx) * draft_probs_stride + + draft_id + ) + pos = tl.load(pos_ptr + start_idx + i) + 
u = tl.rand(seed=seed, offset=pos) + if target_prob >= u * draft_prob: # Accept + token_id = draft_id + else: # Reject + token_id = tl.load(recovered_ids_ptr + start_idx + i).to(tl.int32) + rejected = True tl.store(sampled_ptr + req_idx * sampled_stride + i, token_id) num_sampled += 1 if not rejected: @@ -70,10 +76,10 @@ def rejection_sample( cu_num_logits: torch.Tensor, # [num_reqs] seeds: torch.Tensor, - # [num_reqs] + # [num_draft_tokens + num_reqs] pos: torch.Tensor, # [max_num_reqs] - idx_mapping: torch.Tensor, + temperature: torch.Tensor, num_speculative_steps: int, ) -> tuple[torch.Tensor, torch.Tensor]: num_reqs = cu_num_logits.shape[0] - 1 @@ -101,7 +107,7 @@ def rejection_sample( recovered_ids, seeds, pos, - idx_mapping, + temperature, num_warps=1, ) return sampled, num_sampled @@ -118,6 +124,7 @@ def _sample_recovered_and_bonus_tokens_kernel( draft_probs_ptr, draft_probs_stride, cu_num_logits_ptr, + temperature_ptr, seeds_ptr, pos_ptr, idx_mapping_ptr, @@ -127,6 +134,9 @@ def _sample_recovered_and_bonus_tokens_kernel( token_idx = tl.program_id(0) req_state_idx = tl.load(idx_mapping_ptr + token_idx) end_idx = tl.load(cu_num_logits_ptr + req_state_idx + 1) + seed = tl.load(seeds_ptr + req_state_idx) + pos = tl.load(pos_ptr + token_idx) + temperature = tl.load(temperature_ptr + req_state_idx) is_bonus_token_idx = token_idx == end_idx - 1 block_idx = tl.program_id(1) @@ -137,24 +147,34 @@ def _sample_recovered_and_bonus_tokens_kernel( mask=mask, other=0.0, ) - if not is_bonus_token_idx: - draft_probs = tl.load( - draft_probs_ptr + (token_idx - req_state_idx) * draft_probs_stride + block, - mask=mask, - other=0.0, - ) - target_probs -= draft_probs - # Calculate the seed for exponential noise. - seed = tl.load(seeds_ptr + req_state_idx) - pos = tl.load(pos_ptr + token_idx) - gumbel_seed = tl.randint(seed, pos) - # Generate exponential noise in FP32. 
- u = tl.rand(gumbel_seed, block) - u = tl.maximum(u, 1e-7) - exp_noise = -tl.log(u) + if temperature == 0.0: + # Sample max of target_probs + value, idx = tl.max(target_probs, axis=0, return_indices=True) + else: + # Generate exponential noise + # Calculate the seed for exponential noise. + gumbel_seed = tl.randint(seed, pos) + u = tl.rand(gumbel_seed, block) + u = tl.maximum(u, 1e-7) + exp_noise = -tl.log(u) + + if is_bonus_token_idx: + # Sample from target_probs + value, idx = tl.max(target_probs / exp_noise, axis=0, return_indices=True) + else: + # Sample from max(target_probs - draft_probs, 0) + draft_probs = tl.load( + draft_probs_ptr + + (token_idx - req_state_idx) * draft_probs_stride + + block, + mask=mask, + other=0.0, + ) + target_probs -= draft_probs + # No need to clamp 0 because the maximum is guaranteed to be >= 0 anyway + value, idx = tl.max(target_probs / exp_noise, axis=0, return_indices=True) - value, idx = tl.max(target_probs / exp_noise, axis=0, return_indices=True) token_id = block_idx * BLOCK_SIZE + idx tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id) tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value) @@ -165,9 +185,14 @@ def sample_recovered_and_bonus_tokens( draft_probs: torch.Tensor, # [num_draft_tokens, vocab_size] cu_num_logits: torch.Tensor, # [num_reqs + 1] idx_mapping: torch.Tensor, # [num_draft_tokens + num_reqs] + temperature: torch.Tensor, # [max_num_reqs] seed: torch.Tensor, # [max_num_reqs] pos: torch.Tensor, # [num_reqs] -) -> torch.Tensor: +) -> torch.Tensor: # [num_draft_tokens + num_draft_tokens + num_reqs] + """Return a packed tensor of: + - recovered_ids [num_draft_tokens] + - sampled_ids [num_draft_tokens + num_reqs] + """ num_tokens, vocab_size = target_probs.shape BLOCK_SIZE = 1024 num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) @@ -193,6 +218,7 @@ def sample_recovered_and_bonus_tokens( draft_probs, draft_probs.stride(0), cu_num_logits, + temperature, seed, pos,
idx_mapping, From 6fd5b4c8a5ad6886968430d0bda2ed840f82da8c Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Fri, 27 Feb 2026 01:01:56 +0000 Subject: [PATCH 4/4] Clean Signed-off-by: Andy Lo --- vllm/v1/worker/gpu/sample/sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index 1a31e86e645a..e79757242d6e 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -95,7 +95,6 @@ def sample( logits = processed_logits expanded_logits = logits.shape[0] != idx_mapping_np.shape[0] cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None - # TODO: Check if compute_topk_logprobs can handle 2d sampled logprobs_tensors = compute_topk_logprobs( logits, max_num_logprobs, sampled, cu_num_logits )