Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import torch

from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS


@triton.jit
Expand Down Expand Up @@ -418,6 +421,26 @@ def probabilistic_rejection_sample(
return sampled, rejected_steps + 1


# Scatter per-request sampled token ids into one flat [num_logits] tensor.
# One program instance per request: each copies its accepted tokens from the
# padded 2-D `sampled` matrix to the request's slot range in the flat output.
# Positions past `num_sampled` are left untouched (the caller pre-fills them).
@triton.jit
def _flatten_sampled_kernel(
    # Output: [num_logits] flattened token ids, indexed by cu_num_logits.
    flat_sampled_ptr,
    # Input: [num_reqs, num_speculative_steps + 1] padded sampled token ids.
    sampled_ptr,
    # Row stride of `sampled` (elements per request row).
    sampled_stride,
    # [num_reqs] number of accepted tokens per request.
    num_sampled_ptr,
    # [num_reqs + 1] exclusive prefix sum of per-request logit counts;
    # entry i is request i's start offset in the flat output.
    cu_num_logits_ptr,
):
    req_idx = tl.program_id(0)
    # Start of this request's region in the flat output.
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    num_sampled = tl.load(num_sampled_ptr + req_idx)
    # Sequential copy; num_sampled is small (<= num_speculative_steps + 1).
    for i in range(num_sampled):
        token_id = tl.load(sampled_ptr + req_idx * sampled_stride + i)
        tl.store(flat_sampled_ptr + start_idx + i, token_id)


class RejectionSampler:
def __init__(
self,
Expand All @@ -429,6 +452,40 @@ def __init__(
self.num_speculative_steps = num_speculative_steps
self.use_strict_rejection_sampling = use_strict_rejection_sampling

def _get_logprobs_tensors(
    self,
    input_batch: InputBatch,
    sampled: torch.Tensor,
    num_sampled: torch.Tensor,
    logits: torch.Tensor,
) -> LogprobsTensors | None:
    """Compute top-k logprobs for the accepted tokens of this batch.

    Returns None when no request in the batch asked for logprobs.
    Otherwise flattens the padded per-request sampled token ids into a
    single tensor aligned row-for-row with `logits`, then delegates to
    compute_topk_logprobs.
    """
    topk = self.sampler.sampling_states.max_num_logprobs(
        input_batch.idx_mapping_np
    )
    # NO_LOGPROBS is the sentinel for "no request wants logprobs".
    if topk == NO_LOGPROBS:
        return None

    request_count = input_batch.cu_num_logits.shape[0] - 1
    total_logits = logits.shape[0]
    # Pre-fill with zeros: rows past each request's accepted count are
    # never written by the kernel and must hold a valid token id (0).
    flattened = torch.zeros(
        total_logits, dtype=sampled.dtype, device=sampled.device
    )
    _flatten_sampled_kernel[(request_count,)](
        flattened,
        sampled,
        sampled.stride(0),
        num_sampled,
        input_batch.cu_num_logits,
        num_warps=1,
    )
    # If logits were expanded for speculative steps (more rows than
    # requests), pass the per-request cumulative counts so logprob rows
    # can be attributed back to their requests.
    if total_logits != input_batch.idx_mapping.shape[0]:
        cu_counts = input_batch.cu_num_logits_np.tolist()
    else:
        cu_counts = None
    return compute_topk_logprobs(logits, topk, flattened, cu_counts)

def __call__(
self,
logits: torch.Tensor,
Expand Down Expand Up @@ -460,8 +517,6 @@ def __call__(
draft_sampled,
input_batch.expanded_local_pos,
)
# TODO (TheEpicDolphin): Return logprobs for sampled token ids.
logprobs_tensors = None
sampled, num_sampled = probabilistic_rejection_sample(
processed_logits,
draft_logits,
Expand All @@ -475,6 +530,14 @@ def __call__(
self.sampler.sampling_states.seeds.gpu,
self.num_speculative_steps,
)
logprobs_tensors = self._get_logprobs_tensors(
input_batch,
sampled,
num_sampled,
processed_logits
if self.sampler.logprobs_mode == "processed_logprobs"
else logits,
)

return SamplerOutput(
sampled_token_ids=sampled,
Expand Down
Loading