Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@
EagleModelTypes,
NgramGPUTypes,
]
RejectionSampleMethod = Literal[
"strict",
"probabilistic",
]


@config
Expand Down Expand Up @@ -171,6 +175,12 @@ class SpeculativeConfig:
"""Load config for the draft model. If not specified, will use the load
config from the target model."""

rejection_sample_method: RejectionSampleMethod = "strict"
"""Whether to use strict (target and draft sampled tokens match exactly)
or probabilistic rejection sampling. Both respect the target model
distribution, but the latter yields a higher acceptance rate at the cost
of more memory to cache draft logits."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
Expand Down
64 changes: 30 additions & 34 deletions vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
set_eagle3_aux_hidden_state_layers,
)
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.spec_decode.rejection_sampler import RejectionSampler
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
Expand Down Expand Up @@ -162,6 +162,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.speculator = None
self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False
use_strict_rejection_sampling = False
if self.speculative_config is not None:
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
if self.is_last_pp_rank:
Expand All @@ -172,6 +173,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.use_aux_hidden_state_outputs = True
if self.pp_size > 1:
raise ValueError("EAGLE3 with pipeline parallel is not supported.")
use_strict_rejection_sampling = (
self.speculative_config.rejection_sample_method == "strict"
)

# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
Expand All @@ -183,6 +187,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
model_dtype=self.dtype,
cache_draft_logits=not use_strict_rejection_sampling,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
Expand All @@ -197,6 +203,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
self.rejection_sampler = RejectionSampler(
self.sampler,
num_speculative_steps=self.num_speculative_steps,
use_strict_rejection_sampling=use_strict_rejection_sampling,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)

# CUDA graphs.
Expand Down Expand Up @@ -412,6 +423,7 @@ def _dummy_run(
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
draft_logits_out=self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
Expand All @@ -425,24 +437,16 @@ def _dummy_run(
def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
num_reqs = hidden_states.shape[0]
logits = self.model.compute_logits(hidden_states)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
dummy_input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
expanded_local_pos = torch.zeros(
num_reqs, dtype=torch.int32, device=self.device
dummy_input_batch = InputBatch.make_dummy(
num_reqs, num_reqs, self.input_buffers
)

# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self.sampler(
logits,
idx_mapping,
idx_mapping_np,
idx_mapping_np,
pos,
dummy_input_ids,
expanded_local_pos,
dummy_input_batch,
)

@torch.inference_mode()
Expand Down Expand Up @@ -768,8 +772,6 @@ def sample(
grammar_output: GrammarOutput | None,
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
sample_hidden_states = hidden_states[input_batch.logits_indices]
sample_pos = input_batch.positions[input_batch.logits_indices]
input_ids = input_batch.input_ids[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
Expand All @@ -780,34 +782,27 @@ def sample(
grammar_output.grammar_bitmask,
)

# Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
input_batch.cu_num_logits_np,
sample_pos,
input_ids,
input_batch.expanded_local_pos,
)

if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
num_sampled = input_batch.seq_lens.new_ones(input_batch.num_reqs)
sampler_output = self.sampler(
logits,
input_batch,
)
else:
# Rejection sampling for spec decoding.
sampled_tokens, num_sampled = rejection_sample(
sampler_output.sampled_token_ids,
input_ids,
input_batch.cu_num_logits,
self.num_speculative_steps,
sampler_output = self.rejection_sampler(
logits,
input_batch,
# Draft logits are needed for probabilistic rejection sampling.
self.req_states.draft_logits[input_batch.idx_mapping]
if self.req_states.draft_logits is not None
else None,
)
sampler_output.sampled_token_ids = sampled_tokens

# Get the number of sampled and rejected tokens.
# For chunked prefills, num_sampled and num_rejected are both 0.
num_sampled, num_rejected = get_num_sampled_and_rejected(
num_sampled,
sampler_output.num_sampled,
input_batch.seq_lens,
input_batch.cu_num_logits,
input_batch.idx_mapping,
Expand Down Expand Up @@ -1105,6 +1100,7 @@ def sample_tokens(
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
Expand Down
25 changes: 19 additions & 6 deletions vllm/v1/worker/gpu/sample/gumbel.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def _gumbel_sample_kernel(
local_argmax_stride,
local_max_ptr,
local_max_stride,
processed_logits_ptr,
processed_logits_stride,
logits_ptr,
logits_stride,
expanded_idx_mapping_ptr,
Expand All @@ -79,6 +81,20 @@ def _gumbel_sample_kernel(
logits = logits.to(tl.float32)

temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if (temp != 0.0) and APPLY_TEMPERATURE:
# Apply temperature.
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp

# Store the temperature-applied logits.
if processed_logits_ptr is not None:
tl.store(
processed_logits_ptr + req_state_idx * processed_logits_stride + block,
logits,
mask=mask,
)

if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_state_idx)
Expand All @@ -90,12 +106,6 @@ def _gumbel_sample_kernel(
u = tl.maximum(u, 1e-7)
gumbel_noise = -tl.log(-tl.log(u))

# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp

# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))

Expand All @@ -112,6 +122,7 @@ def gumbel_sample(
seed: torch.Tensor, # [max_num_reqs]
pos: torch.Tensor, # [num_tokens]
apply_temperature: bool,
processed_logits_out: torch.Tensor | None = None, # [num_reqs, vocab_size]
) -> torch.Tensor:
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024
Expand All @@ -133,6 +144,8 @@ def gumbel_sample(
local_argmax.stride(0),
local_max,
local_max.stride(0),
processed_logits_out,
processed_logits_out.stride(0) if processed_logits_out is not None else 0,
logits,
logits.stride(0),
expanded_idx_mapping,
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/worker/gpu/sample/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ class SamplerOutput:
sampled_token_ids: torch.Tensor
logprobs_tensors: LogprobsTensors | None
num_nans: torch.Tensor | None
num_sampled: torch.Tensor | None
42 changes: 32 additions & 10 deletions vllm/v1/worker/gpu/sample/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
Expand Down Expand Up @@ -56,13 +57,15 @@ def apply_staged_writes(self) -> None:
def __call__(
self,
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
input_batch: InputBatch,
) -> SamplerOutput:
expanded_idx_mapping = input_batch.expanded_idx_mapping
idx_mapping_np = input_batch.idx_mapping_np
cu_num_logits_np = input_batch.cu_num_logits_np
expanded_local_pos = input_batch.expanded_local_pos
pos = input_batch.positions[input_batch.logits_indices]
input_ids = input_batch.input_ids[input_batch.logits_indices]

# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None
Expand Down Expand Up @@ -95,18 +98,19 @@ def __call__(
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
num_nans=num_nans,
num_sampled=input_batch.seq_lens.new_ones(input_batch.num_reqs),
)
return sampler_output

def sample(
def apply_sampling_params(
self,
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)

Expand Down Expand Up @@ -143,13 +147,31 @@ def sample(
self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)

# Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p(
return self.sampling_states.apply_top_k_top_p(
logits, expanded_idx_mapping, idx_mapping_np
)

def sample(
self,
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
processed_logits = self.apply_sampling_params(
logits,
expanded_idx_mapping,
idx_mapping_np,
pos,
input_ids,
expanded_local_pos,
)

# Sample the next token.
sampled = gumbel_sample(
logits,
processed_logits,
expanded_idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
Expand Down
12 changes: 12 additions & 0 deletions vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def generate_draft(
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
draft_logits_out: torch.Tensor | None = None,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
Expand All @@ -166,6 +167,9 @@ def generate_draft(
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=draft_logits_out[:, step]
if draft_logits_out is not None
else None,
)
self.draft_tokens[:num_reqs, step] = draft_tokens

Expand Down Expand Up @@ -219,6 +223,8 @@ def propose(
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
# [max_num_reqs, num_speculative_steps, vocab_size]
draft_logits_out: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
Expand Down Expand Up @@ -271,6 +277,7 @@ def propose(
idx_mapping.copy_(input_batch.idx_mapping)
self.temperature.copy_(temperature)
self.seeds.copy_(seeds)

# Gather the values and copy them to the pre-allocated buffers.
pos = self.input_buffers.positions[:num_reqs]
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
Expand All @@ -283,7 +290,11 @@ def propose(
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=draft_logits_out[:, 0]
if draft_logits_out is not None
else None,
)

if self.num_speculative_steps == 1:
# Early exit.
return draft_tokens.view(-1, 1)
Expand Down Expand Up @@ -365,6 +376,7 @@ def propose(
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=batch_desc.cg_mode,
draft_logits_out=draft_logits_out,
)
return self.draft_tokens[:num_reqs]

Expand Down
Loading
Loading