Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions tests/v1/sample/test_rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,6 @@ def native_sample_recovered_tokens(
# state because RNG state advances after each call.
generator.set_state(states[i])

inv_q = q.reciprocal()

out = torch.empty_like(draft_token_ids)

for req_idx in range(batch_size):
Expand All @@ -593,7 +591,7 @@ def native_sample_recovered_tokens(
0.0
)

score = prob * inv_q[req_idx]
score = prob / q[req_idx]
recovered_id = torch.argmax(score, dim=-1)
out[token_idx] = recovered_id
return out
Expand Down
200 changes: 151 additions & 49 deletions vllm/v1/sample/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def apply_logits_processors(
predict_bonus_token=False,
spec_token_ids=sampling_metadata.spec_token_ids,
)

return logits

@staticmethod
Expand Down Expand Up @@ -465,23 +466,28 @@ def rejection_sample(
if sampling_metadata.all_greedy:
return output_token_ids

# Compute probability distribution from target logits.
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
assert target_probs.is_contiguous()

# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
# Precompute one random threshold per speculative token for inverse-CDF
# recovery. Unlike the eager Gumbel/exponential race, this avoids
# generating batch_size * vocab_size random numbers up front; the lazy
# kernel only scans the vocabulary if a rejection actually occurs.
#
# This is distribution-equivalent to sampling from
# (target_prob - draft_prob)^+ but does not preserve the exact random
# stream used by the previous exponential-race implementation.
recovery_uniform_probs = generate_uniform_probs(
num_tokens,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
sampling_metadata.generators,
device,
)

# Triton tile size for vocab reduction during lazy recovery.
# Kept large to reduce loop iterations but still fit in SRAM.
BLOCK_SIZE: tl.constexpr = 8192

# Rejection sampling for random sampling requests.
assert uniform_probs is not None
rejection_random_sample_kernel[(batch_size,)](
Expand All @@ -491,12 +497,13 @@ def rejection_sample(
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
recovery_uniform_probs,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
synthetic_conditional_rates,
BLOCK_SIZE,
NO_DRAFT_PROBS=draft_probs is None,
SYNTHETIC_MODE=synthetic_mode,
)
Expand Down Expand Up @@ -637,15 +644,14 @@ def generate_uniform_probs(
# uniform_prob is sampled to be exactly 0.0 as reported in
# https://github.com/pytorch/pytorch/issues/16706. Using float64
# mitigates the issue.
uniform_probs = torch.rand(
uniform_probs = torch.empty(
(num_tokens,),
dtype=torch.float64,
device=device,
)
uniform_probs.uniform_()
Comment thread
masterFoad marked this conversation as resolved.
start_idx = 0
for req_idx, n in enumerate(num_draft_tokens):
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if n == 0:
continue
end_idx = start_idx + n
Expand All @@ -656,6 +662,28 @@ def generate_uniform_probs(
return uniform_probs


def generate_recovery_noise(
    batch_size: int,
    vocab_size: int,
    num_draft_tokens: list[int],
    generators: dict[int, torch.Generator],
    device: torch.device,
) -> torch.Tensor:
    """Sample the exponential noise used for eager recovered-token sampling.

    One row of noise is produced per request. Rows belonging to requests
    that have their own generator (and at least one draft token) are
    re-sampled from that generator so the result is reproducible per request.

    Args:
        batch_size: Number of requests, i.e. number of rows in the result.
        vocab_size: Number of columns in the result.
        num_draft_tokens: Draft-token count per request, indexed by request.
        generators: Mapping from request index to that request's generator.
        device: Device on which the noise tensor is allocated.

    Returns:
        A float32 tensor of shape ``(batch_size, vocab_size)`` filled with
        samples drawn by ``Tensor.exponential_``.
    """
    # NOTE(woosuk): Create only one distribution for each request.
    q = torch.empty(
        (batch_size, vocab_size),
        dtype=torch.float32,
        device=device,
    )
    # Fill every row from the default RNG first; seeded requests are
    # overwritten row-by-row below.
    q.exponential_()
    for i, generator in generators.items():
        # Do not generate random numbers for requests with no draft tokens.
        # This can be important for reproducibility.
        if num_draft_tokens[i] > 0:
            q[i].exponential_(generator=generator)
    return q


def sample_recovered_tokens(
max_spec_len: int,
num_draft_tokens: list[int],
Expand All @@ -670,22 +698,22 @@ def sample_recovered_tokens(
sampling_metadata: SamplingMetadata,
device: torch.device,
) -> torch.Tensor:
# NOTE(woosuk): Create only one distribution for each request.
"""
Compatibility helper for tests and callers that need eager recovery.

The production rejection path computes recovery lazily inside
rejection_random_sample_kernel, but this helper remains useful for
correctness tests against the previous eager implementation.
"""
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
device=device,
q = generate_recovery_noise(
batch_size,
vocab_size,
num_draft_tokens,
sampling_metadata.generators,
device,
)
q.exponential_()
for i, generator in sampling_metadata.generators.items():
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if num_draft_tokens[i] > 0:
q[i].exponential_(generator=generator)

inv_q = q.reciprocal()

recovered_token_ids = torch.empty_like(draft_token_ids)
BLOCK_SIZE = 8192
Expand All @@ -695,7 +723,7 @@ def sample_recovered_tokens(
draft_token_ids,
draft_probs,
target_probs,
inv_q,
q,
vocab_size,
BLOCK_SIZE,
NO_DRAFT_PROBS=draft_probs is None,
Expand Down Expand Up @@ -766,12 +794,13 @@ def rejection_random_sample_kernel(
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
recovery_uniform_probs_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
synthetic_conditional_rates_ptr, # [num_speculative_tokens] or None
BLOCK_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
SYNTHETIC_MODE: tl.constexpr,
):
Expand All @@ -788,34 +817,107 @@ def rejection_random_sample_kernel(
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
token_idx = start_idx + pos
draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
uniform_prob = tl.load(uniform_probs_ptr + token_idx)
if SYNTHETIC_MODE:
rate = tl.load(synthetic_conditional_rates_ptr + pos)
accepted = uniform_prob < rate
else:
if NO_DRAFT_PROBS:
draft_prob = 1
draft_prob = 1.0
else:
draft_prob = tl.load(
draft_probs_ptr
+ (start_idx + pos) * vocab_size
+ draft_token_id
draft_probs_ptr + token_idx * vocab_size + draft_token_id
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
target_probs_ptr + token_idx * vocab_size + draft_token_id
)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
accepted = draft_prob > 0 and target_prob / draft_prob >= uniform_prob

if accepted:
token_id = draft_token_id
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
draft_token_id,
)
else:
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id
)

# Lazy recovery: compute recovered token only after the first
# rejection. We sample from (target_prob - draft_prob)^+ using
# inverse CDF, which needs only one random number per draft
# token instead of one exponential race value per vocab entry.
if NO_DRAFT_PROBS:
draft_target_prob = tl.load(
target_probs_ptr + token_idx * vocab_size + draft_token_id
)
total_prob = 1.0 - draft_target_prob
else:
total_prob = 0.0
for v in range(0, vocab_size, BLOCK_SIZE):
vocab_offset = v + tl.arange(0, BLOCK_SIZE)
vocab_mask = vocab_offset < vocab_size
draft_prob_v = tl.load(
draft_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=vocab_mask,
other=0.0,
)
target_prob_v = tl.load(
target_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=vocab_mask,
other=0.0,
)
prob = tl.maximum(target_prob_v - draft_prob_v, 0.0)

total_prob += tl.sum(prob, axis=0)

threshold = (
tl.load(recovery_uniform_probs_ptr + token_idx) * total_prob
)
cumulative_prob = 0.0
recovered_id = 0
found = False
for v in range(0, vocab_size, BLOCK_SIZE):
vocab_offset = v + tl.arange(0, BLOCK_SIZE)
vocab_mask = vocab_offset < vocab_size

if NO_DRAFT_PROBS:
prob = tl.load(
target_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=(vocab_mask & (vocab_offset != draft_token_id)),
other=0.0,
)
else:
draft_prob_v = tl.load(
draft_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=vocab_mask,
other=0.0,
)
target_prob_v = tl.load(
target_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=vocab_mask,
other=0.0,
)
prob = tl.maximum(target_prob_v - draft_prob_v, 0.0)

prefix_prob = cumulative_prob + tl.cumsum(prob, axis=0)
candidates = tl.where(
(prefix_prob > threshold) & vocab_mask,
vocab_offset,
vocab_size,
)
candidate_id = tl.min(candidates, axis=0)
if (not found) and candidate_id < vocab_size:
recovered_id = candidate_id
found = True
cumulative_prob += tl.sum(prob, axis=0)

tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
recovered_id,
)

if not rejected:
# If all tokens are accepted, append the bonus token.
Expand Down Expand Up @@ -857,7 +959,7 @@ def sample_recovered_tokens_kernel(
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
inv_q_ptr, # [batch_size, vocab_size]
q_ptr, # [batch_size, vocab_size]
vocab_size,
BLOCK_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
Expand Down Expand Up @@ -901,17 +1003,17 @@ def sample_recovered_tokens_kernel(
other=0.0,
)
prob = tl.maximum(target_prob - draft_prob, 0.0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here
# because `tl.argmax` will select the maximum value.

inv_q = tl.load(
inv_q_ptr + req_idx * vocab_size + vocab_offset,
q = tl.load(
q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_mask,
other=0.0,
other=1.0,
)

# Local tile reduction
score = prob * inv_q
# Local tile reduction.
score = prob / q
local_max, local_id = tl.max(score, axis=0, return_indices=True)

if local_max > max_val:
Expand Down
Loading