Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions tests/v1/spec_decode/test_rejection_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from scipy import stats

from vllm.platforms import current_platform
from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
sample_recovered_and_bonus_tokens,
)

# Device-type string reported by the current platform (compared against
# "cuda" in the skip condition below, and used as the tensor device).
device = current_platform.device_type


@pytest.mark.skipif(device != "cuda", reason="Requires CUDA")
def test_sample_recovered_and_bonus_tokens_correctness():
    """Statistical check of sample_recovered_and_bonus_tokens.

    Every recovered (non-bonus) position must follow max(0, target - draft)
    renormalized, and every bonus position must follow the target distribution
    unchanged. We accumulate many independently seeded draws and compare the
    empirical counts against the theoretical frequencies with a chi-squared
    goodness-of-fit test per token position.
    """
    # Layout: 3 requests contributing 2 / 1 / 2 draft tokens, 8 logits total.
    # req 0 (2 draft + 1 bonus): tok 0, 1, 2
    # req 1 (1 draft + 1 bonus): tok 3, 4
    # req 2 (2 draft + 1 bonus): tok 5, 6, 7
    num_reqs = 3
    num_tokens = 8
    vocab_size = 3

    target_probs = torch.tensor(
        [
            [0.6, 0.3, 0.1],  # req 0 recovered
            [0.3, 0.4, 0.3],  # req 0 recovered
            [0.1, 0.7, 0.2],  # req 0 bonus
            [0.5, 0.3, 0.2],  # req 1 recovered
            [0.2, 0.5, 0.3],  # req 1 bonus
            [0.4, 0.4, 0.2],  # req 2 recovered
            [0.2, 0.6, 0.2],  # req 2 recovered
            [0.3, 0.3, 0.4],  # req 2 bonus
        ],
        device=device,
        dtype=torch.float32,
    )
    draft_probs = torch.tensor(
        [
            [0.1, 0.2, 0.7],  # req 0 draft 0
            [0.1, 0.2, 0.7],  # req 0 draft 1
            [0.3, 0.1, 0.6],  # req 1 draft 0
            [0.1, 0.3, 0.6],  # req 2 draft 0
            [0.1, 0.2, 0.7],  # req 2 draft 1
        ],
        device=device,
        dtype=torch.float32,
    )

    # Theoretical distributions (recovered = max(0, target-draft) normalised):
    # tok 0: max(0, [0.5, 0.1, -0.6]) → [5/6, 1/6, 0]
    # tok 1: max(0, [0.2, 0.2, -0.4]) → [1/2, 1/2, 0]
    # tok 2: bonus → [0.1, 0.7, 0.2]
    # tok 3: max(0, [0.2, 0.2, -0.4]) → [1/2, 1/2, 0]
    # tok 4: bonus → [0.2, 0.5, 0.3]
    # tok 5: max(0, [0.3, 0.1, -0.4]) → [3/4, 1/4, 0]
    # tok 6: max(0, [0.1, 0.4, -0.5]) → [1/5, 4/5, 0]
    # tok 7: bonus → [0.3, 0.3, 0.4]
    expected_dists = torch.tensor(
        [
            [5 / 6, 1 / 6, 0.0],  # req 0 recovered
            [0.5, 0.5, 0.0],  # req 0 recovered
            [0.1, 0.7, 0.2],  # req 0 bonus
            [0.5, 0.5, 0.0],  # req 1 recovered
            [0.2, 0.5, 0.3],  # req 1 bonus
            [3 / 4, 1 / 4, 0.0],  # req 2 recovered
            [1 / 5, 4 / 5, 0.0],  # req 2 recovered
            [0.3, 0.3, 0.4],  # req 2 bonus
        ],
        device=device,
        dtype=torch.float32,
    )

    cu_num_logits = torch.tensor([0, 3, 5, 8], dtype=torch.int32, device=device)
    idx_mapping = torch.tensor(
        [0, 0, 0, 1, 1, 2, 2, 2], dtype=torch.int32, device=device
    )
    # Exact values don't matter, only used for seeding
    pos = torch.arange(num_tokens, dtype=torch.int64, device=device)

    # Build empirical per-position histograms over many seeded draws.
    n_trials = 30_000
    counts = torch.zeros(num_tokens, vocab_size, dtype=torch.int64, device=device)
    row_idx = torch.arange(num_tokens, device=device)
    base_seed = torch.arange(1, num_reqs + 1, dtype=torch.int64, device=device)
    for trial in range(n_trials):
        # Give each request a distinct seed that changes every trial.
        seed = base_seed + trial * num_reqs
        samples = sample_recovered_and_bonus_tokens(
            target_probs, draft_probs, cu_num_logits, idx_mapping, seed, pos
        )
        # One sampled token per position; rows are distinct, so plain
        # advanced-index accumulation is safe here.
        counts[row_idx, samples] += 1

    # Zero-probability entries must never be sampled; chisquare cannot handle
    # zero expected frequencies, so verify them separately and mask them out.
    expected_freq = expected_dists * n_trials  # [num_tokens, vocab_size]
    support = expected_freq > 0  # [num_tokens, vocab_size]
    assert torch.where(~support, counts, 0).sum() == 0, "Sampled a zero-prob token"

    alpha = 1e-4
    for tok_pos in range(num_tokens):
        mask = support[tok_pos]
        observed = counts[tok_pos, mask].cpu().numpy().astype(float)
        theoretical = expected_freq[tok_pos, mask].cpu().numpy().astype(float)
        _, p_value = stats.chisquare(observed, f_exp=theoretical, sum_check=False)
        assert p_value > alpha, (
            f"Token position {tok_pos}: empirical distribution significantly "
            f"differs from expected ({p_value=:.6f})"
        )
5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu/input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class InputBatch:
cu_num_logits: torch.Tensor
cu_num_logits_np: np.ndarray

# [num_draft_tokens]
draft_logits: torch.Tensor

# Whether any requests in batch use structured output.
has_structured_output_reqs: bool

Expand All @@ -89,6 +92,7 @@ def make_dummy(
cls,
num_reqs: int,
num_tokens: int,
vocab_size: int,
input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch":
Expand Down Expand Up @@ -150,6 +154,7 @@ def make_dummy(
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
draft_logits=torch.empty(0, vocab_size, device=device, dtype=torch.float32),
has_structured_output_reqs=False,
)

Expand Down
69 changes: 45 additions & 24 deletions vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
set_eagle3_aux_hidden_state_layers,
)
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
Expand Down Expand Up @@ -389,7 +388,7 @@ def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self.sampler(
self.sampler.sample(
logits,
idx_mapping,
idx_mapping_np,
Expand Down Expand Up @@ -601,11 +600,23 @@ def prepare_inputs(
expanded_local_pos = torch.zeros(
num_reqs, dtype=torch.int32, device=self.device
)
draft_logits = torch.empty(
0, self.vocab_size, device=self.device, dtype=torch.float32
)
else:
num_draft_tokens = np.array(
[len(draft_tokens.get(req_id, ())) for req_id in req_ids],
dtype=np.int32,
)
draft_logits = torch.concat(
[
self.req_states.draft_logits[
self.req_states.req_id_to_index[req_id], :num_draft_token
]
for req_id, num_draft_token in zip(req_ids, num_draft_tokens)
],
dim=0,
)
total_num_draft_tokens = int(num_draft_tokens.sum())
total_num_logits = num_reqs + total_num_draft_tokens

Expand Down Expand Up @@ -750,6 +761,7 @@ def prepare_inputs(
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
draft_logits=draft_logits,
has_structured_output_reqs=scheduler_output.has_structured_output_requests,
)

Expand Down Expand Up @@ -792,36 +804,37 @@ def sample(
grammar_output.grammar_bitmask,
)

# Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
input_batch.cu_num_logits_np,
sample_pos,
input_ids,
input_batch.expanded_local_pos,
)

if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
num_sampled = torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device
# Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler.sample(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
input_batch.cu_num_logits_np,
sample_pos,
input_ids,
input_batch.expanded_local_pos,
)
else:
# Rejection sampling for spec decoding.
sampled_tokens, num_sampled = rejection_sample(
sampler_output.sampled_token_ids,
input_ids,
sampler_output = self.sampler.rejection_sample(
logits,
input_batch.draft_logits,
input_batch.cu_num_logits,
input_batch.cu_num_logits_np,
input_batch.idx_mapping,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
sample_pos,
input_ids,
input_batch.expanded_local_pos,
self.num_speculative_steps,
)
sampler_output.sampled_token_ids = sampled_tokens

# Get the number of sampled and rejected tokens.
# For chunked prefills, num_sampled and num_rejected are both 0.
num_sampled, num_rejected = get_num_sampled_and_rejected(
num_sampled,
sampler_output.num_sampled,
input_batch.seq_lens,
input_batch.cu_num_logits,
input_batch.idx_mapping,
Expand Down Expand Up @@ -935,6 +948,7 @@ def execute_model(
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens_after_padding,
vocab_size=self.vocab_size,
input_buffers=self.input_buffers,
device=self.device,
)
Expand Down Expand Up @@ -1090,7 +1104,7 @@ def sample_tokens(
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
)
if self.speculator is not None:
draft_tokens = self.speculator.propose(
speculation = self.speculator.propose(
input_batch,
hidden_states,
aux_hidden_states,
Expand All @@ -1101,8 +1115,15 @@ def sample_tokens(
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
self.req_states.draft_tokens[input_batch.idx_mapping] = (
speculation.draft_tokens
)
self.req_states.draft_logits[input_batch.idx_mapping] = (
speculation.draft_logits
)
self.draft_tokens_handler.set_draft_tokens(
input_batch, speculation.draft_tokens
)

if self.use_async_scheduling:
return async_output
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/worker/gpu/sample/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ class SamplerOutput:
sampled_token_ids: torch.Tensor
logprobs_tensors: LogprobsTensors | None
num_nans: torch.Tensor | None
num_sampled: torch.Tensor
Loading