Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
EagleModelTypes,
NgramGPUTypes,
]
RejectionSampleMethod = Literal["strict", "probabilistic", "synthetic"]
RejectionSampleMethod = Literal["standard", "synthetic"]
DraftSampleMethod = Literal["greedy", "gumbel"]


@config
Expand Down Expand Up @@ -183,11 +184,11 @@ class SpeculativeConfig:
"""Load config for the draft model. If not specified, will use the load
config from the target model."""

rejection_sample_method: RejectionSampleMethod = "strict"
"""Whether to use strict (target and draft sampled tokens match exactly)
or probabilistic rejection sampling. Both respect the target model
distribution, but the latter yields a higher acceptance rate at the cost
of more memory to cache draft logits."""
rejection_sample_method: RejectionSampleMethod = "standard"
"""The rejection sampling method to use. 'standard' uses probabilistic
rejection sampling (with or without cached draft logits, controlled by
draft_sample_method). 'synthetic' accepts draft tokens with a decaying
probability calibrated to synthetic_acceptance_rate."""

synthetic_acceptance_rates: list[float] | None = None
"""Per-position *unconditional* acceptance rates for synthetic rejection
Expand Down Expand Up @@ -248,6 +249,14 @@ def _resolve_synthetic_acceptance_rates(
)
return SpeculativeConfig._acceptance_length_to_rates(length, n)

draft_sample_method: DraftSampleMethod = "greedy"
"""How the draft model samples tokens. 'greedy' always picks the argmax
token, and the draft probabilities are treated as one-hot during rejection
sampling. 'gumbel' adds Gumbel noise for stochastic sampling, and the full
draft logits are used for the probability ratio test during rejection
sampling. This comes at the cost of additional GPU memory usage. This
parameter currently only applies to Model Runner V2."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
Expand Down
50 changes: 29 additions & 21 deletions vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
)

self.draft_logits: torch.Tensor | None = None
if self.speculative_config.rejection_sample_method == "probabilistic":
if self.speculative_config.draft_sample_method == "gumbel":
self.draft_logits = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
Expand Down Expand Up @@ -215,6 +215,28 @@ def run_model(
last_hidden_states, hidden_states = ret_hidden_states
return last_hidden_states, hidden_states

def _sample_draft(
    self,
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    pos: torch.Tensor,
    step: int,
) -> torch.Tensor:
    """Sample one draft token per request from the draft model's logits.

    Dispatches on ``self.draft_logits``: it is allocated only when
    ``speculative_config.draft_sample_method == "gumbel"`` (see ``__init__``),
    so its presence selects stochastic Gumbel sampling and its absence
    selects plain greedy argmax.

    Args:
        logits: Draft-model logits for the current step.
            # presumably shaped [num_reqs, vocab_size] — TODO confirm
        idx_mapping: Per-row mapping into the persistent request state
            (temperature/seed buffers). # NOTE(review): semantics inferred
            # from gumbel_sample usage — verify against its definition.
        pos: Token positions used to derive the per-position random noise.
        step: Speculative step index; selects the cache slot
            ``self.draft_logits[:, step]`` when draft logits are kept.

    Returns:
        A tensor of sampled draft token ids, one per row of ``logits``.
    """
    if self.draft_logits is not None:
        # Gumbel path: stochastic sampling. The temperature-processed
        # logits are also written into self.draft_logits[:, step] so the
        # rejection sampler can later compute the draft probability ratio.
        # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
        # used for draft and target sampling.
        return gumbel_sample(
            logits,
            idx_mapping,
            self.temperature,
            self.seeds,
            pos + 1,
            apply_temperature=True,
            processed_logits_out=self.draft_logits[:, step],
        )
    else:
        # Greedy path: pick the argmax token; draft probabilities are
        # treated as one-hot downstream, so no logits need to be cached.
        return logits.argmax(dim=-1)
Comment thread
benchislett marked this conversation as resolved.

def prefill(
self,
num_reqs: int,
Expand All @@ -240,18 +262,11 @@ def prefill(
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)

# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
self.draft_tokens[:num_reqs, 0] = gumbel_sample(
self.draft_tokens[:num_reqs, 0] = self._sample_draft(
logits,
idx_mapping,
self.temperature,
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=self.draft_logits[:, 0]
if self.draft_logits is not None
else None,
pos,
step=0,
)
self.hidden_states[:num_reqs] = hidden_states[last_token_indices]
self.input_buffers.positions[:num_reqs] = pos
Expand Down Expand Up @@ -281,18 +296,11 @@ def generate_draft(
hidden_states = hidden_states[:num_reqs]
logits = self.model.compute_logits(last_hidden_states)

# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
draft_tokens = self._sample_draft(
logits,
idx_mapping,
self.temperature,
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=self.draft_logits[:, step]
if self.draft_logits is not None
else None,
pos,
step=step,
)
self.draft_tokens[:num_reqs, step] = draft_tokens

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _compute_global_lse(


@triton.jit
def _compute_block_max_and_sumexp_kernel(
def _compute_block_stats_kernel(
# [num_logits, num_blocks]
target_local_argmax_ptr,
target_local_argmax_stride,
Expand Down Expand Up @@ -77,6 +77,7 @@ def _compute_block_max_and_sumexp_kernel(
vocab_size,
num_speculative_steps,
BLOCK_SIZE: tl.constexpr,
HAS_DRAFT_LOGITS: tl.constexpr,
):
logit_idx = tl.program_id(0)
draft_step_idx = tl.load(expanded_local_pos_ptr + logit_idx)
Expand Down Expand Up @@ -112,24 +113,6 @@ def _compute_block_max_and_sumexp_kernel(
value,
)
else:
# Get local draft max and summed exponentials.
draft_logits = tl.load(
draft_logits_ptr
+ req_state_idx * draft_logits_stride_0
+ draft_step_idx * draft_logits_stride_1
+ block_offsets,
mask=mask,
other=float("-inf"),
).to(tl.float32)
draft_max, draft_sumexp = _compute_block_max_and_sumexp(draft_logits)
tl.store(
draft_local_max_ptr + logit_idx * draft_local_max_stride + block_idx,
draft_max,
)
tl.store(
draft_local_sumexp_ptr + logit_idx * draft_local_sumexp_stride + block_idx,
draft_sumexp,
)
# Get local target max and summed exponentials.
target_logits = tl.load(
target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
Expand All @@ -147,6 +130,27 @@ def _compute_block_max_and_sumexp_kernel(
+ block_idx,
target_sumexp,
)
if HAS_DRAFT_LOGITS:
# Get local draft max and summed exponentials.
draft_logits = tl.load(
draft_logits_ptr
+ req_state_idx * draft_logits_stride_0
+ draft_step_idx * draft_logits_stride_1
+ block_offsets,
mask=mask,
other=float("-inf"),
).to(tl.float32)
draft_max, draft_sumexp = _compute_block_max_and_sumexp(draft_logits)
tl.store(
draft_local_max_ptr + logit_idx * draft_local_max_stride + block_idx,
draft_max,
)
tl.store(
draft_local_sumexp_ptr
+ logit_idx * draft_local_sumexp_stride
+ block_idx,
draft_sumexp,
)


@triton.jit
Expand Down Expand Up @@ -196,6 +200,7 @@ def _probabilistic_rejection_kernel(
pos_ptr,
vocab_num_blocks,
PADDED_VOCAB_NUM_BLOCKS: tl.constexpr,
HAS_DRAFT_LOGITS: tl.constexpr,
):
req_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
Expand Down Expand Up @@ -238,12 +243,6 @@ def _probabilistic_rejection_kernel(
target_logit = tl.load(
target_logits_ptr + logit_idx * target_logits_stride + draft_sampled
).to(tl.float32)
draft_logit = tl.load(
draft_logits_ptr
+ req_state_idx * draft_logits_stride_0
+ i * draft_logits_stride_1
+ draft_sampled
).to(tl.float32)
target_lse = _compute_global_lse(
target_local_max_ptr,
target_local_max_stride,
Expand All @@ -253,19 +252,29 @@ def _probabilistic_rejection_kernel(
vocab_num_blocks,
PADDED_VOCAB_NUM_BLOCKS,
)
draft_lse = _compute_global_lse(
draft_local_max_ptr,
draft_local_max_stride,
draft_local_sumexp_ptr,
draft_local_sumexp_stride,
logit_idx,
vocab_num_blocks,
PADDED_VOCAB_NUM_BLOCKS,
)
target_log_prob = target_logit - target_lse
draft_log_prob = draft_logit - draft_lse
pos = tl.load(pos_ptr + logit_idx)
u = tl_rand64(seed, pos, includes_zero=False)
if HAS_DRAFT_LOGITS:
draft_logit = tl.load(
draft_logits_ptr
+ req_state_idx * draft_logits_stride_0
+ i * draft_logits_stride_1
+ draft_sampled
).to(tl.float32)
draft_lse = _compute_global_lse(
draft_local_max_ptr,
draft_local_max_stride,
draft_local_sumexp_ptr,
draft_local_sumexp_stride,
logit_idx,
vocab_num_blocks,
PADDED_VOCAB_NUM_BLOCKS,
)
draft_log_prob = draft_logit - draft_lse
else:
# One-hot draft: q(draft_token) = 1, log_q = 0.
draft_log_prob = 0
# Probability ratio test: p(x) > u * q(x)
# Equivalent log form: log_p(x) > log(u) + log_q(x)
accepted &= target_log_prob > tl.log(u) + draft_log_prob
Expand Down Expand Up @@ -301,6 +310,8 @@ def _resample_kernel(
cu_num_logits_ptr,
# [num_logits]
expanded_idx_mapping_ptr,
# [num_logits]
draft_sampled_ptr,
# [max_num_reqs]
temp_ptr,
# [max_num_reqs]
Expand All @@ -309,6 +320,7 @@ def _resample_kernel(
pos_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
HAS_DRAFT_LOGITS: tl.constexpr,
):
req_idx = tl.program_id(0)
resample_idx = tl.load(rejected_step_ptr + req_idx)
Expand All @@ -327,22 +339,17 @@ def _resample_kernel(
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
target_logits = tl.load(
target_logits_ptr + resample_token_idx * target_logits_stride + block,
mask=mask,
other=float("-inf"),
).to(tl.float32)

# Compute the residual logits to resample the rejected token
# from. In the case of no rejections (bonus token), we directly
# use the target logits.
# Compute the residual logits to resample the rejected token from.
if is_bonus:
residual_logits = tl.load(
target_logits_ptr + resample_token_idx * target_logits_stride + block,
mask=mask,
other=float("-inf"),
).to(tl.float32)
else:
target_logits = tl.load(
target_logits_ptr + resample_token_idx * target_logits_stride + block,
mask=mask,
other=float("-inf"),
).to(tl.float32)
# Bonus token (no rejections). Directly use the target logits.
residual_logits = target_logits
elif HAS_DRAFT_LOGITS:
draft_logits = tl.load(
draft_logits_ptr
+ req_state_idx * draft_logits_stride_0
Expand All @@ -365,6 +372,15 @@ def _resample_kernel(
target_log_probs + tl.log(1 - ratio),
float("-inf"),
).to(tl.float32)
else:
# One-hot draft. The residual is just the target distribution with
# the rejected draft token probability zeroed out.
rejected_draft_token = tl.load(draft_sampled_ptr + resample_token_idx + 1)
residual_logits = tl.where(
block != rejected_draft_token,
target_logits,
float("-inf"),
).to(tl.float32)

# Resample the rejected/bonus token.
value, idx = gumbel_block_argmax(
Expand Down Expand Up @@ -456,7 +472,7 @@ def probabilistic_rejection_sample(
# [num_logits, V]
target_logits: torch.Tensor,
# [max_num_reqs, num_speculative_steps, V]
draft_logits: torch.Tensor,
draft_logits: torch.Tensor | None,
# [num_logits]
draft_sampled: torch.Tensor,
# [num_reqs + 1]
Expand All @@ -477,9 +493,17 @@ def probabilistic_rejection_sample(
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
num_logits, vocab_size = target_logits.shape
has_draft_logits = draft_logits is not None

# Gather draft logits, compute target argmax for greedy, and
# compute per-block LSE and max for non-greedy requests.
if draft_logits is None:
# When draft_logits is None, create a dummy tensor so that Triton
# kernel signatures receive valid pointers/strides. The kernels
# will never read from it when HAS_DRAFT_LOGITS=False.
draft_logits = target_logits.new_empty(1, 1, 1)

# Compute the block-level logits stats, such as target argmax
# (for greedy requests), and target max + softmax exponential
# (for non-greedy requests).
VOCAB_BLOCK_SIZE = 8192
vocab_num_blocks = triton.cdiv(vocab_size, VOCAB_BLOCK_SIZE)
padded_vocab_num_blocks = triton.next_power_of_2(vocab_num_blocks)
Expand All @@ -498,7 +522,7 @@ def probabilistic_rejection_sample(
draft_local_sumexp = target_logits.new_empty(
num_logits, vocab_num_blocks, dtype=torch.float32
)
_compute_block_max_and_sumexp_kernel[(num_logits, vocab_num_blocks)](
_compute_block_stats_kernel[(num_logits, vocab_num_blocks)](
target_local_argmax,
target_local_argmax.stride(0),
target_local_max,
Expand All @@ -520,6 +544,7 @@ def probabilistic_rejection_sample(
vocab_size,
num_speculative_steps,
BLOCK_SIZE=VOCAB_BLOCK_SIZE,
HAS_DRAFT_LOGITS=has_draft_logits,
)

# Sample up until the first rejected/bonus token, and store
Expand Down Expand Up @@ -559,6 +584,7 @@ def probabilistic_rejection_sample(
pos,
vocab_num_blocks,
PADDED_VOCAB_NUM_BLOCKS=padded_vocab_num_blocks,
HAS_DRAFT_LOGITS=has_draft_logits,
num_warps=1,
)

Expand Down Expand Up @@ -587,11 +613,13 @@ def probabilistic_rejection_sample(
num_sampled,
cu_num_logits,
expanded_idx_mapping,
draft_sampled,
temperature,
seed,
pos,
vocab_size,
BLOCK_SIZE=RESAMPLE_BLOCK_SIZE,
HAS_DRAFT_LOGITS=has_draft_logits,
)

# Insert the resampled tokens into the output sampled.
Expand Down
Loading
Loading