diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py b/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
index 9bcf629b8034..e1f483919b00 100644
--- a/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
+++ b/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
@@ -3,11 +3,14 @@
 import torch
 
 from vllm.triton_utils import tl, triton
+from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.worker.gpu.input_batch import InputBatch
 from vllm.v1.worker.gpu.metrics.logits import get_num_nans
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
+from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.sampler import Sampler
+from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS
 
 
 @triton.jit
@@ -418,6 +421,26 @@ def probabilistic_rejection_sample(
     return sampled, rejected_steps + 1
 
 
+@triton.jit
+def _flatten_sampled_kernel(
+    # [num_logits]
+    flat_sampled_ptr,
+    # [num_reqs, num_speculative_steps + 1]
+    sampled_ptr,
+    sampled_stride,
+    # [num_reqs]
+    num_sampled_ptr,
+    # [num_reqs + 1]
+    cu_num_logits_ptr,
+):
+    req_idx = tl.program_id(0)
+    start_idx = tl.load(cu_num_logits_ptr + req_idx)
+    num_sampled = tl.load(num_sampled_ptr + req_idx)
+    for i in range(num_sampled):
+        token_id = tl.load(sampled_ptr + req_idx * sampled_stride + i)
+        tl.store(flat_sampled_ptr + start_idx + i, token_id)
+
+
 class RejectionSampler:
     def __init__(
         self,
@@ -429,6 +452,40 @@ def __init__(
         self.num_speculative_steps = num_speculative_steps
         self.use_strict_rejection_sampling = use_strict_rejection_sampling
 
+    def _get_logprobs_tensors(
+        self,
+        input_batch: InputBatch,
+        sampled: torch.Tensor,
+        num_sampled: torch.Tensor,
+        logits: torch.Tensor,
+    ) -> LogprobsTensors | None:
+        max_num_logprobs = self.sampler.sampling_states.max_num_logprobs(
+            input_batch.idx_mapping_np
+        )
+        if max_num_logprobs == NO_LOGPROBS:
+            return None
+
+        num_reqs = input_batch.cu_num_logits.shape[0] - 1
+        num_logits = logits.shape[0]
+        flat_sampled = torch.zeros(
+            num_logits, dtype=sampled.dtype, device=sampled.device
+        )
+        _flatten_sampled_kernel[(num_reqs,)](
+            flat_sampled,
+            sampled,
+            sampled.stride(0),
+            num_sampled,
+            input_batch.cu_num_logits,
+            num_warps=1,
+        )
+        expanded_logits = num_logits != input_batch.idx_mapping.shape[0]
+        return compute_topk_logprobs(
+            logits,
+            max_num_logprobs,
+            flat_sampled,
+            input_batch.cu_num_logits_np.tolist() if expanded_logits else None,
+        )
+
     def __call__(
         self,
         logits: torch.Tensor,
@@ -460,8 +517,6 @@ def __call__(
             draft_sampled,
             input_batch.expanded_local_pos,
         )
-        # TODO (TheEpicDolphin): Return logprobs for sampled token ids.
-        logprobs_tensors = None
         sampled, num_sampled = probabilistic_rejection_sample(
             processed_logits,
             draft_logits,
@@ -475,6 +530,14 @@ def __call__(
             self.sampler.sampling_states.seeds.gpu,
             self.num_speculative_steps,
        )
+        logprobs_tensors = self._get_logprobs_tensors(
+            input_batch,
+            sampled,
+            num_sampled,
+            processed_logits
+            if self.sampler.logprobs_mode == "processed_logprobs"
+            else logits,
+        )
 
         return SamplerOutput(
             sampled_token_ids=sampled,
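
For reviewers: below is a rough pure-PyTorch reference of what the new _flatten_sampled_kernel computes, which may help when checking the scatter logic. This is an illustrative sketch, not part of the patch; flatten_sampled_reference and the example tensors are hypothetical.

import torch

def flatten_sampled_reference(
    sampled: torch.Tensor,        # [num_reqs, num_speculative_steps + 1]
    num_sampled: torch.Tensor,    # [num_reqs], accepted tokens per request
    cu_num_logits: torch.Tensor,  # [num_reqs + 1], prefix sums of logit rows
) -> torch.Tensor:
    # Mirrors the kernel: request r writes its first num_sampled[r] token
    # ids into the flat buffer starting at offset cu_num_logits[r]; slots
    # for rejected steps keep the zero fill, as with torch.zeros in the patch.
    num_logits = int(cu_num_logits[-1])
    flat_sampled = torch.zeros(num_logits, dtype=sampled.dtype)
    for req_idx in range(sampled.shape[0]):
        start = int(cu_num_logits[req_idx])
        n = int(num_sampled[req_idx])
        flat_sampled[start : start + n] = sampled[req_idx, :n]
    return flat_sampled

# Example (hypothetical values): two requests, two speculative steps
# (up to 3 tokens each, including the bonus token); request 0 had its
# second draft token rejected, so it contributes only 2 tokens.
sampled = torch.tensor([[11, 12, -1], [21, 22, 23]])
num_sampled = torch.tensor([2, 3])
cu_num_logits = torch.tensor([0, 3, 6])
print(flatten_sampled_reference(sampled, num_sampled, cu_num_logits))
# tensor([11, 12,  0, 21, 22, 23])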