diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index ee94ea87912d..360f1c32f03b 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -57,6 +57,10 @@ EagleModelTypes, NgramGPUTypes, ] +RejectionSampleMethod = Literal[ + "strict", + "probabilistic", +] @config @@ -171,6 +175,12 @@ class SpeculativeConfig: """Load config for the draft model. If not specified, will use the load config from the target model.""" + rejection_sample_method: RejectionSampleMethod = "strict" + """Whether to use strict (target and draft sampled tokens match exactly) + or probabilistic rejection sampling. Both respect the target model + distribution, but the latter yields a higher acceptance rate at the cost + of more memory to cache draft logits.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index c4fe833ff30e..ca2aacfc35ef 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -90,7 +90,7 @@ from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import ( set_eagle3_aux_hidden_state_layers, ) -from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample +from vllm.v1.worker.gpu.spec_decode.rejection_sampler import RejectionSampler from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker @@ -162,6 +162,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.speculator = None self.num_speculative_steps = 0 self.use_aux_hidden_state_outputs = False + use_strict_rejection_sampling = False if self.speculative_config is not None: self.num_speculative_steps = self.speculative_config.num_speculative_tokens if self.is_last_pp_rank: @@ -172,6 +173,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): 
self.use_aux_hidden_state_outputs = True if self.pp_size > 1: raise ValueError("EAGLE3 with pipeline parallel is not supported.") + use_strict_rejection_sampling = ( + self.speculative_config.rejection_sample_method == "strict" + ) # Draft tokens propagation - for spec-dec + struct outputs. self.draft_tokens_handler = DraftTokensHandler(self.device) @@ -183,6 +187,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): num_speculative_steps=self.num_speculative_steps, vocab_size=self.vocab_size, device=self.device, + model_dtype=self.dtype, + cache_draft_logits=not use_strict_rejection_sampling, ) self.input_buffers = InputBuffers( max_num_reqs=self.max_num_reqs, @@ -197,6 +203,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): logprobs_mode=self.model_config.logprobs_mode, num_speculative_tokens=self.num_speculative_steps + 1, ) + self.rejection_sampler = RejectionSampler( + self.sampler, + num_speculative_steps=self.num_speculative_steps, + use_strict_rejection_sampling=use_strict_rejection_sampling, + ) self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs) # CUDA graphs. 
@@ -412,6 +423,7 @@ def _dummy_run( next_prefill_tokens=self.req_states.next_prefill_tokens, temperature=self.sampler.sampling_states.temperature.gpu, seeds=self.sampler.sampling_states.seeds.gpu, + draft_logits_out=self.req_states.draft_logits, num_tokens_across_dp=num_tokens_across_dp, dummy_run=True, skip_attn_for_dummy_run=skip_attn, @@ -425,24 +437,16 @@ def _dummy_run( def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None: num_reqs = hidden_states.shape[0] logits = self.model.compute_logits(hidden_states) - idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device) - idx_mapping_np = np.arange(num_reqs, dtype=np.int32) - pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device) - dummy_input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=self.device) - expanded_local_pos = torch.zeros( - num_reqs, dtype=torch.int32, device=self.device + dummy_input_batch = InputBatch.make_dummy( + num_reqs, num_reqs, self.input_buffers ) + # NOTE(woosuk): During the initial memory profiling, the sampler may skip # top_k, top_p, and logprobs, using less GPU memory than what is possible # during actual execution. self.sampler( logits, - idx_mapping, - idx_mapping_np, - idx_mapping_np, - pos, - dummy_input_ids, - expanded_local_pos, + dummy_input_batch, ) @torch.inference_mode() @@ -768,8 +772,6 @@ def sample( grammar_output: GrammarOutput | None, ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]: sample_hidden_states = hidden_states[input_batch.logits_indices] - sample_pos = input_batch.positions[input_batch.logits_indices] - input_ids = input_batch.input_ids[input_batch.logits_indices] logits = self.model.compute_logits(sample_hidden_states) if grammar_output is not None: # Apply grammar bitmask to the logits in-place. @@ -780,34 +782,27 @@ def sample( grammar_output.grammar_bitmask, ) - # Sample tokens and compute logprobs (if needed). 
- sampler_output = self.sampler( - logits, - input_batch.expanded_idx_mapping, - input_batch.idx_mapping_np, - input_batch.cu_num_logits_np, - sample_pos, - input_ids, - input_batch.expanded_local_pos, - ) - if input_batch.num_draft_tokens == 0: # No draft tokens (common case). - num_sampled = input_batch.seq_lens.new_ones(input_batch.num_reqs) + sampler_output = self.sampler( + logits, + input_batch, + ) else: # Rejection sampling for spec decoding. - sampled_tokens, num_sampled = rejection_sample( - sampler_output.sampled_token_ids, - input_ids, - input_batch.cu_num_logits, - self.num_speculative_steps, + sampler_output = self.rejection_sampler( + logits, + input_batch, + # Draft logits are needed for probabilistic rejection sampling. + self.req_states.draft_logits[input_batch.idx_mapping] + if self.req_states.draft_logits is not None + else None, ) - sampler_output.sampled_token_ids = sampled_tokens # Get the number of sampled and rejected tokens. # For chunked prefills, num_sampled and num_rejected are both 0. 
num_sampled, num_rejected = get_num_sampled_and_rejected( - num_sampled, + sampler_output.num_sampled, input_batch.seq_lens, input_batch.cu_num_logits, input_batch.idx_mapping, @@ -1105,6 +1100,7 @@ def sample_tokens( self.req_states.next_prefill_tokens, self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.seeds.gpu, + self.req_states.draft_logits, num_tokens_across_dp=num_tokens_across_dp, ) self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens diff --git a/vllm/v1/worker/gpu/sample/gumbel.py b/vllm/v1/worker/gpu/sample/gumbel.py index 43be45614b19..1f10d7bb2c0b 100644 --- a/vllm/v1/worker/gpu/sample/gumbel.py +++ b/vllm/v1/worker/gpu/sample/gumbel.py @@ -55,6 +55,8 @@ def _gumbel_sample_kernel( local_argmax_stride, local_max_ptr, local_max_stride, + processed_logits_ptr, + processed_logits_stride, logits_ptr, logits_stride, expanded_idx_mapping_ptr, @@ -79,6 +81,20 @@ def _gumbel_sample_kernel( logits = logits.to(tl.float32) temp = tl.load(temp_ptr + req_state_idx).to(tl.float32) + if (temp != 0.0) and APPLY_TEMPERATURE: + # Apply temperature. + # NOTE(woosuk): Match the behavior of _temperature_kernel. + # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too. + logits = logits / temp + + # Store the temperature-applied logits. + if processed_logits_ptr is not None: + tl.store( + processed_logits_ptr + req_state_idx * processed_logits_stride + block, + logits, + mask=mask, + ) + if temp != 0.0: # Calculate the seed for gumbel noise. seed = tl.load(seeds_ptr + req_state_idx) @@ -90,12 +106,6 @@ def _gumbel_sample_kernel( u = tl.maximum(u, 1e-7) gumbel_noise = -tl.log(-tl.log(u)) - # Apply temperature. - if APPLY_TEMPERATURE: - # NOTE(woosuk): Match the behavior of _temperature_kernel. - # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too. - logits = logits / temp - # Apply gumbel noise. 
logits = tl.where(mask, logits + gumbel_noise, float("-inf")) @@ -112,6 +122,7 @@ def gumbel_sample( seed: torch.Tensor, # [max_num_reqs] pos: torch.Tensor, # [num_tokens] apply_temperature: bool, + processed_logits_out: torch.Tensor | None = None, # [num_reqs, vocab_size] ) -> torch.Tensor: num_tokens, vocab_size = logits.shape BLOCK_SIZE = 1024 @@ -133,6 +144,8 @@ def gumbel_sample( local_argmax.stride(0), local_max, local_max.stride(0), + processed_logits_out, + processed_logits_out.stride(0) if processed_logits_out is not None else 0, logits, logits.stride(0), expanded_idx_mapping, diff --git a/vllm/v1/worker/gpu/sample/output.py b/vllm/v1/worker/gpu/sample/output.py index 13e8cf1d6c1e..f38ac8affd88 100644 --- a/vllm/v1/worker/gpu/sample/output.py +++ b/vllm/v1/worker/gpu/sample/output.py @@ -12,3 +12,4 @@ class SamplerOutput: sampled_token_ids: torch.Tensor logprobs_tensors: LogprobsTensors | None num_nans: torch.Tensor | None + num_sampled: torch.Tensor | None diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index d774c8f9b65d..ec0087d9c8b1 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -7,6 +7,7 @@ import vllm.envs as envs from vllm.config.model import LogprobsMode from vllm.sampling_params import SamplingParams +from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.metrics.logits import get_num_nans from vllm.v1.worker.gpu.sample.bad_words import BadWordsState from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample @@ -56,13 +57,15 @@ def apply_staged_writes(self) -> None: def __call__( self, logits: torch.Tensor, - expanded_idx_mapping: torch.Tensor, - idx_mapping_np: np.ndarray, - cu_num_logits_np: np.ndarray, - pos: torch.Tensor, - input_ids: torch.Tensor, - expanded_local_pos: torch.Tensor, + input_batch: InputBatch, ) -> SamplerOutput: + expanded_idx_mapping = input_batch.expanded_idx_mapping + idx_mapping_np = input_batch.idx_mapping_np 
+ cu_num_logits_np = input_batch.cu_num_logits_np + expanded_local_pos = input_batch.expanded_local_pos + pos = input_batch.positions[input_batch.logits_indices] + input_ids = input_batch.input_ids[input_batch.logits_indices] + # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear # that num_nans is computed before applying penalties and temperature. num_nans = get_num_nans(logits) if self.compute_nans else None @@ -95,10 +98,11 @@ def __call__( sampled_token_ids=sampled.view(-1, 1), logprobs_tensors=logprobs_tensors, num_nans=num_nans, + num_sampled=input_batch.seq_lens.new_ones(input_batch.num_reqs), ) return sampler_output - def sample( + def apply_sampling_params( self, logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, @@ -106,7 +110,7 @@ def sample( pos: torch.Tensor, input_ids: torch.Tensor, expanded_local_pos: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: # Copy logits to a new FP32 tensor. logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits) @@ -143,13 +147,31 @@ def sample( self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np) # Apply top_k and/or top_p. This might or might not return a new tensor. - logits = self.sampling_states.apply_top_k_top_p( + return self.sampling_states.apply_top_k_top_p( logits, expanded_idx_mapping, idx_mapping_np ) + def sample( + self, + logits: torch.Tensor, + expanded_idx_mapping: torch.Tensor, + idx_mapping_np: np.ndarray, + pos: torch.Tensor, + input_ids: torch.Tensor, + expanded_local_pos: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + processed_logits = self.apply_sampling_params( + logits, + expanded_idx_mapping, + idx_mapping_np, + pos, + input_ids, + expanded_local_pos, + ) + # Sample the next token. 
sampled = gumbel_sample( - logits, + processed_logits, expanded_idx_mapping, self.sampling_states.temperature.gpu, self.sampling_states.seeds.gpu, diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 8d3c3ba8e9ef..922031a52180 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -140,6 +140,7 @@ def generate_draft( slot_mappings: dict[str, torch.Tensor] | None, num_tokens_across_dp: torch.Tensor | None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + draft_logits_out: torch.Tensor | None = None, ) -> None: pos = self.input_buffers.positions[:num_reqs] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] @@ -166,6 +167,9 @@ def generate_draft( self.seeds, pos + 1, apply_temperature=True, + processed_logits_out=draft_logits_out[:, step] + if draft_logits_out is not None + else None, ) self.draft_tokens[:num_reqs, step] = draft_tokens @@ -219,6 +223,8 @@ def propose( temperature: torch.Tensor, # [max_num_reqs] seeds: torch.Tensor, + # [max_num_reqs, num_speculative_steps, vocab_size] + draft_logits_out: torch.Tensor | None, num_tokens_across_dp: torch.Tensor | None = None, dummy_run: bool = False, skip_attn_for_dummy_run: bool = False, @@ -271,6 +277,7 @@ def propose( idx_mapping.copy_(input_batch.idx_mapping) self.temperature.copy_(temperature) self.seeds.copy_(seeds) + # Gather the values and copy them to the pre-allocated buffers. pos = self.input_buffers.positions[:num_reqs] torch.gather(input_batch.positions, 0, last_token_indices, out=pos) @@ -283,7 +290,11 @@ def propose( self.seeds, pos + 1, apply_temperature=True, + processed_logits_out=draft_logits_out[:, 0] + if draft_logits_out is not None + else None, ) + if self.num_speculative_steps == 1: # Early exit. 
return draft_tokens.view(-1, 1) @@ -365,6 +376,7 @@ def propose( slot_mappings_updated, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=batch_desc.cg_mode, + draft_logits_out=draft_logits_out, ) return self.draft_tokens[:num_reqs] diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py deleted file mode 100644 index b542ffbd3f23..000000000000 --- a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import torch - -from vllm.triton_utils import tl, triton - - -@triton.jit -def _rejection_sample_kernel( - sampled_ptr, # [num_reqs, num_speculative_steps + 1] - sampled_stride, - num_sampled_ptr, # [num_reqs] - target_sampled_ptr, # [num_draft_tokens + num_reqs] - input_ids_ptr, # [num_draft_tokens + num_reqs] - cu_num_logits_ptr, # [num_reqs + 1] -): - req_idx = tl.program_id(0) - start_idx = tl.load(cu_num_logits_ptr + req_idx) - end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) - num_tokens = end_idx - start_idx - - num_sampled = 0 - rejected = False - for i in range(num_tokens - 1): - if not rejected: - target_sampled = tl.load(target_sampled_ptr + start_idx + i) - draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1) - tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled) - num_sampled += 1 - if target_sampled != draft_sampled: - rejected = True - if not rejected: - target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1) - tl.store( - sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled - ) - num_sampled += 1 - tl.store(num_sampled_ptr + req_idx, num_sampled) - - -def rejection_sample( - # [num_draft_tokens + num_reqs] - target_sampled: torch.Tensor, - # [num_draft_tokens + num_reqs] - input_ids: torch.Tensor, - # [num_reqs + 1] - cu_num_logits: torch.Tensor, - num_speculative_steps: int, -) 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Rejection sampling for speculative decoding.

Two schemes are implemented:

* "strict": a draft token is accepted only when it is exactly equal to the
  token the target model sampled at the same position (token-level match).
* "probabilistic": a draft token is accepted based on a comparison of the
  target and draft probabilities of that token; on rejection, a replacement
  token is resampled from the residual distribution
  ``max(p_target - p_draft, 0)``. This yields a higher acceptance rate but
  requires the draft logits to be cached.
"""
import torch

from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler


@triton.jit
def _strict_rejection_sample_kernel(
    # One program per request. Accepts the longest prefix of draft tokens
    # that exactly matches the target model's sampled tokens, plus one
    # final (bonus or replacement) target token.
    sampled_ptr,  # [num_reqs, num_speculative_steps + 1]
    sampled_stride,
    num_sampled_ptr,  # [num_reqs]
    target_sampled_ptr,  # [num_draft_tokens + num_reqs]
    input_ids_ptr,  # [num_draft_tokens + num_reqs]
    cu_num_logits_ptr,  # [num_reqs + 1]
):
    req_idx = tl.program_id(0)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    # Number of logits for this request: draft tokens + 1 bonus position.
    num_tokens = end_idx - start_idx

    num_sampled = 0
    rejected = False
    for i in range(num_tokens - 1):
        if not rejected:
            # Target's sample at step i vs. the draft token that was fed as
            # input at step i (input_ids is shifted by one, hence `i + 1`).
            target_sampled = tl.load(target_sampled_ptr + start_idx + i)
            draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
            # The target's token is always emitted: it either confirms the
            # draft token or replaces the first mismatching one.
            tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
            num_sampled += 1
            if target_sampled != draft_sampled:
                rejected = True
    if not rejected:
        # All draft tokens accepted: also emit the bonus token sampled from
        # the last target logit.
        target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
        tl.store(
            sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
        )
        num_sampled += 1
    tl.store(num_sampled_ptr + req_idx, num_sampled)


def strict_rejection_sample(
    # [num_draft_tokens + num_reqs]
    target_sampled: torch.Tensor,
    # [num_draft_tokens + num_reqs]
    draft_sampled: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Strict (exact-match) rejection sampling.

    Args:
        target_sampled: Tokens sampled from the target model, flattened
            across requests.
        draft_sampled: The draft tokens (the input ids at the logit
            positions), flattened across requests.
        cu_num_logits: Cumulative number of logits per request; request
            ``r`` owns the flat range ``[cu_num_logits[r], cu_num_logits[r+1])``.
        num_speculative_steps: Maximum number of draft tokens per request.

    Returns:
        Tuple of ``(sampled, num_sampled)`` where ``sampled`` is
        ``[num_reqs, num_speculative_steps + 1]`` (entries past
        ``num_sampled[r]`` are uninitialized) and ``num_sampled`` is the
        per-request count of valid tokens.
    """
    num_reqs = cu_num_logits.shape[0] - 1
    sampled = torch.empty(
        num_reqs,
        num_speculative_steps + 1,
        dtype=target_sampled.dtype,
        device=target_sampled.device,
    )
    num_sampled = torch.empty(
        num_reqs,
        dtype=torch.int32,
        device=target_sampled.device,
    )
    _strict_rejection_sample_kernel[(num_reqs,)](
        sampled,
        sampled.stride(0),
        num_sampled,
        target_sampled,
        draft_sampled,
        cu_num_logits,
        num_warps=1,
    )
    return sampled, num_sampled


@triton.jit
def _probabilistic_rejection_sample_kernel(
    # One program per request. Walks the draft tokens in order and accepts
    # each one based on the ratio of target to draft probability; stops at
    # the first rejection.
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # [num_reqs]
    rejected_steps_ptr,
    # [num_logits]
    draft_sampled_ptr,
    # [num_logits, V]
    target_probs_ptr,
    target_probs_stride,
    # [num_reqs, num_speculative_steps, V]
    draft_probs_ptr,
    draft_probs_stride_0,
    draft_probs_stride_1,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_logits]
    pos_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [num_reqs]
    seeds_ptr,
):
    req_idx = tl.program_id(0)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx
    # Per-request seed, looked up through the batch -> request-state mapping.
    seed = tl.load(seeds_ptr + tl.load(idx_mapping_ptr + req_idx))

    rejected_step = 0
    accepted = True
    for i in range(num_tokens - 1):
        if accepted:
            # Draft token proposed at step i (input ids are shifted by one).
            draft_sampled = tl.load(draft_sampled_ptr + start_idx + i + 1)
            target_prob = tl.load(
                target_probs_ptr + (start_idx + i) * target_probs_stride + draft_sampled
            )
            draft_prob = tl.load(
                draft_probs_ptr
                + req_idx * draft_probs_stride_0
                + i * draft_probs_stride_1
                + draft_sampled
            )
            pos = tl.load(pos_ptr + start_idx + i)
            # Draw one uniform in [0, 1) keyed on (seed, pos). tl.rand takes
            # a tensor of offsets, hence the size-1 arange and the tl.sum to
            # reduce back to a scalar.
            u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1)))
            # Accept iff u < p_target / p_draft, i.e. with probability
            # min(1, p_target / p_draft).
            accepted &= target_prob > u * draft_prob
            # Tentatively emit the draft token; if this step was rejected,
            # position `rejected_step` is later overwritten by the token
            # resampled from the residual distribution.
            tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
            # Counts consecutive accepted steps; on exit this is the index
            # of the first rejected step (or of the bonus position if every
            # draft token was accepted).
            rejected_step += accepted
    tl.store(rejected_steps_ptr + req_idx, rejected_step)


@triton.jit
def _compute_residual_logits_kernel(
    # Grid: (num_reqs, num_vocab_blocks). Builds, for each request, the
    # logits of the distribution to resample the rejected (or bonus) token
    # from, and records the position to seed that resample with.
    # [num_reqs, V]
    residual_logits_ptr,
    residual_logits_stride,
    # [num_reqs]
    residual_pos_ptr,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # [num_logits, V]
    target_probs_ptr,
    target_probs_stride,
    # [num_reqs, num_speculative_steps, V]
    draft_probs_ptr,
    draft_probs_stride_0,
    draft_probs_stride_1,
    # [num_reqs]
    rejected_step_ptr,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_logits]
    pos_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    block_idx = tl.program_id(1)

    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    rejected_draft_step = tl.load(rejected_step_ptr + req_idx)
    # Flat index of the logit at which resampling happens.
    rejected_logit_idx = start_idx + rejected_draft_step

    block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block_offsets < vocab_size

    if rejected_logit_idx < end_idx - 1:
        # A draft token was rejected: resample from the residual
        # distribution max(p_target - p_draft, 0). Note tl.log(0) = -inf,
        # which correctly zeroes those tokens out under softmax/gumbel.
        target_probs = tl.load(
            target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets,
            mask=mask,
            other=0.0,
        )
        draft_probs = tl.load(
            draft_probs_ptr
            + req_idx * draft_probs_stride_0
            + rejected_draft_step * draft_probs_stride_1
            + block_offsets,
            mask=mask,
            other=0.0,
        )
        residual_probs = tl.maximum(target_probs - draft_probs, 0.0)
        residual_logits = tl.log(residual_probs)
    else:
        # This is a bonus token. Directly return the target logits.
        residual_logits = tl.load(
            target_logits_ptr
            + rejected_logit_idx * target_logits_stride
            + block_offsets,
            mask=mask,
            other=0.0,
        )

    tl.store(
        residual_logits_ptr + req_idx * residual_logits_stride + block_offsets,
        residual_logits,
        mask=mask,
    )

    # First block computes the residual logit positions.
    if block_idx == 0:
        pos_val = tl.load(pos_ptr + rejected_logit_idx)
        tl.store(residual_pos_ptr + req_idx, pos_val)


def probabilistic_rejection_sample(
    # [num_draft_tokens + num_reqs, V]
    target_logits: torch.Tensor,
    # [num_reqs, num_speculative_steps, V]
    draft_logits: torch.Tensor,
    # [num_draft_tokens + num_reqs]
    draft_sampled: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    # [num_logits]
    pos: torch.Tensor,
    # [num_reqs]
    idx_mapping: torch.Tensor,
    temperature: torch.Tensor,
    seeds: torch.Tensor,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Probabilistic rejection sampling that preserves the target
    distribution.

    Accepts each draft token with probability ``min(1, p_target / p_draft)``
    and replaces the first rejected token (or appends the bonus token) with
    a sample from the residual distribution.

    Both ``target_logits`` and ``draft_logits`` are expected to already be
    "processed" (temperature etc. applied), so the final gumbel resample
    runs with ``apply_temperature=False``.

    Returns:
        Tuple of ``(sampled, num_sampled)`` where ``sampled`` is
        ``[num_reqs, num_speculative_steps + 1]`` and ``num_sampled`` is
        ``rejected_steps + 1`` (accepted prefix plus the resampled token).
    """
    num_reqs = cu_num_logits.shape[0] - 1
    device = target_logits.device
    vocab_size = target_logits.shape[-1]

    # Compute target and draft probs.
    target_probs = torch.softmax(target_logits, dim=-1)
    draft_probs = torch.softmax(draft_logits, dim=-1)

    # Rejection sample.
    # [num_reqs, num_speculative_steps + 1]
    sampled = torch.empty(
        num_reqs,
        num_speculative_steps + 1,
        dtype=torch.int64,
        device=device,
    )
    # [num_reqs]
    rejected_steps = torch.empty(
        num_reqs,
        dtype=torch.int64,
        device=device,
    )
    _probabilistic_rejection_sample_kernel[(num_reqs,)](
        sampled,
        sampled.stride(0),
        rejected_steps,
        draft_sampled,
        target_probs,
        target_probs.stride(0),
        draft_probs,
        draft_probs.stride(0),
        draft_probs.stride(1),
        cu_num_logits,
        pos,
        idx_mapping,
        seeds,
        num_warps=1,
    )

    # Compute the logits and positions to resample the rejected/bonus
    # tokens from.
    # [num_reqs, vocab_size]
    residual_logits = torch.empty(
        num_reqs,
        vocab_size,
        dtype=target_logits.dtype,
        device=device,
    )
    # [num_reqs]
    residual_pos = torch.empty(
        num_reqs,
        dtype=pos.dtype,
        device=device,
    )
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _compute_residual_logits_kernel[(num_reqs, num_blocks)](
        residual_logits,
        residual_logits.stride(0),
        residual_pos,
        target_logits,
        target_logits.stride(0),
        target_probs,
        target_probs.stride(0),
        draft_probs,
        draft_probs.stride(0),
        draft_probs.stride(1),
        rejected_steps,
        cu_num_logits,
        pos,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Gumbel sample tokens from the residual distribution.
    resampled = gumbel_sample(
        residual_logits,
        idx_mapping,
        temperature,
        seeds,
        residual_pos,
        apply_temperature=False,
    )
    # Overwrite the first rejected position (or the bonus position) with
    # the resampled token.
    sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1))

    return sampled, rejected_steps + 1


class RejectionSampler:
    """Samples target tokens for speculative decoding via rejection
    sampling, dispatching to either the strict or the probabilistic scheme.
    """

    def __init__(
        self,
        sampler: Sampler,
        num_speculative_steps: int,
        use_strict_rejection_sampling: bool = True,
    ):
        # The regular sampler is reused both for the strict path (full
        # sampling) and for the probabilistic path (logits processing only).
        self.sampler = sampler
        self.num_speculative_steps = num_speculative_steps
        self.use_strict_rejection_sampling = use_strict_rejection_sampling

    def __call__(
        self,
        logits: torch.Tensor,
        input_batch: InputBatch,
        draft_logits: torch.Tensor | None = None,
    ) -> SamplerOutput:
        """Run rejection sampling over the target ``logits``.

        ``draft_logits`` (processed draft-model logits, one row per request
        and speculative step) is required only for the probabilistic scheme.
        """
        # Draft tokens are the input ids at the logit positions.
        draft_sampled = input_batch.input_ids[input_batch.logits_indices]
        # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
        # that num_nans is computed before applying penalties and temperature.
        num_nans = get_num_nans(logits) if self.sampler.compute_nans else None

        if self.use_strict_rejection_sampling:
            # Sample target tokens normally, then keep the longest
            # exactly-matching prefix of draft tokens.
            sampler_output = self.sampler(
                logits,
                input_batch,
            )
            logprobs_tensors = sampler_output.logprobs_tensors
            sampled, num_sampled = strict_rejection_sample(
                sampler_output.sampled_token_ids.view(-1),
                draft_sampled,
                input_batch.cu_num_logits,
                self.num_speculative_steps,
            )
        else:
            assert draft_logits is not None
            pos = input_batch.positions[input_batch.logits_indices]
            # Apply temperature/penalties/top-k/top-p so the target and
            # (already-processed) draft distributions are comparable.
            processed_logits = self.sampler.apply_sampling_params(
                logits,
                input_batch.expanded_idx_mapping,
                input_batch.idx_mapping_np,
                pos,
                draft_sampled,
                input_batch.expanded_local_pos,
            )
            # TODO (TheEpicDolphin): Return logprobs for sampled token ids.
            logprobs_tensors = None
            sampled, num_sampled = probabilistic_rejection_sample(
                processed_logits,
                draft_logits,
                draft_sampled,
                input_batch.cu_num_logits,
                pos,
                input_batch.idx_mapping,
                self.sampler.sampling_states.temperature.gpu,
                self.sampler.sampling_states.seeds.gpu,
                self.num_speculative_steps,
            )

        return SamplerOutput(
            sampled_token_ids=sampled,
            logprobs_tensors=logprobs_tensors,
            num_nans=num_nans,
            num_sampled=num_sampled,
        )
+ self.draft_logits: torch.Tensor | None = None + if cache_draft_logits: + self.draft_logits = torch.zeros( + self.max_num_reqs, + self.num_speculative_steps, + self.vocab_size, + dtype=model_dtype, + device=device, + ) + self.next_prefill_tokens = torch.zeros( self.max_num_reqs, dtype=torch.int32, device=device )