Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions vllm/v1/sample/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.worker.gpu.sample.logprob import compute_token_logprobs

_SAMPLING_EPS = 1e-5

Expand Down Expand Up @@ -78,7 +77,8 @@ def forward(
# This is different from the V0 sampler, which uses the logits that
# are used for sampling (after penalties and temperature scaling).
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
raw_logprobs: torch.Tensor | None = None
if num_logprobs is not None or sampling_metadata.logprob_token_ids:
if logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits)
elif logprobs_mode == "raw_logits":
Expand Down Expand Up @@ -107,8 +107,9 @@ def forward(
# This is used by generative_scoring API to get logprobs for specific tokens
logprob_token_ids_tensors = None
if sampling_metadata.logprob_token_ids:
assert raw_logprobs is not None
logprob_token_ids_tensors = self.gather_specific_token_logprobs(
logits, sampling_metadata.logprob_token_ids, sampled
raw_logprobs, sampling_metadata.logprob_token_ids, sampled
)
Comment thread
njhill marked this conversation as resolved.

if num_logprobs is None:
Expand Down Expand Up @@ -144,22 +145,19 @@ def forward(

def gather_specific_token_logprobs(
self,
logits: torch.Tensor,
logprobs: torch.Tensor,
logprob_token_ids: dict[int, list[int]],
sampled: torch.Tensor,
) -> LogprobsTensors | None:
"""Compute logprobs for specific token IDs using Triton kernel.
"""Gather logprobs for specific token IDs requested per request.

This method handles heterogeneous token ID lists across requests by
padding shorter lists to max length and using a fused Triton kernel
for efficient log_softmax + gather computation.

Benchmarks show the Triton kernel approach is ~1.4x faster than sparse
gather for batch sizes > 1 due to the fused kernel reducing memory
bandwidth requirements.
Used by the generative_scoring API to return logprobs for an explicit
set of token ids rather than the top-k. Handles heterogeneous token
id lists across requests by padding shorter lists to the max length.

Args:
logits: [batch_size, vocab_size] tensor of logits
logprobs: [batch_size, vocab_size] tensor of (raw) logprobs to
gather from.
logprob_token_ids: dict mapping req_index -> list of token IDs
sampled: [batch_size] tensor of sampled token IDs

Expand All @@ -170,8 +168,8 @@ def gather_specific_token_logprobs(
if not logprob_token_ids:
return None

batch_size = logits.shape[0]
device = logits.device
batch_size = logprobs.shape[0]
device = logprobs.device

# Find max number of tokens across all requests
max_num_tokens = max(len(tids) for tids in logprob_token_ids.values())
Expand Down Expand Up @@ -200,19 +198,24 @@ def gather_specific_token_logprobs(
# tensor so we don't need to D2H + re-upload.
token_ids_tensor[:, 0] = sampled

# Compute logprobs using the fused Triton kernel (log_softmax + gather)
logprobs = compute_token_logprobs(logits, token_ids_tensor)
# Gather logprobs at the requested token ids.
gathered_logprobs = logprobs.gather(-1, token_ids_tensor)

# Mask invalid (padded) positions with -inf
logprobs = logprobs.masked_fill(~valid_mask, float("-inf"))
gathered_logprobs = gathered_logprobs.masked_fill(~valid_mask, float("-inf"))

# Compute ranks for the sampled token
sampled_logits = logits.gather(-1, sampled.unsqueeze(-1))
token_ranks = (logits > sampled_logits).sum(dim=-1)
# Compute ranks for the sampled token. log_softmax is monotonic w.r.t.
# the original logits, so ranks computed from logprobs are equivalent.
sampled_logprobs = logprobs.gather(-1, sampled.unsqueeze(-1))
# Avoid 0/1 specialization recompile on the batch dimension of the
# compiled batched_count_greater_than. See gather_logprobs for context.
torch._dynamo.decorators.mark_unbacked(logprobs, 0)
torch._dynamo.decorators.mark_unbacked(sampled_logprobs, 0)
token_ranks = batched_count_greater_than(logprobs, sampled_logprobs)
Comment thread
njhill marked this conversation as resolved.

return LogprobsTensors(
logprob_token_ids=token_ids_tensor.to(torch.int32),
logprobs=logprobs,
logprobs=gathered_logprobs,
selected_token_ranks=token_ranks,
)

Expand Down Expand Up @@ -252,7 +255,10 @@ def sample(
greedy_sampled = self.greedy_sample(logits)
if sampling_metadata.all_greedy:
processed_logprobs = None
if sampling_metadata.max_num_logprobs is not None:
if (
sampling_metadata.max_num_logprobs is not None
or sampling_metadata.logprob_token_ids
):
if logprobs_mode == "processed_logits":
processed_logprobs = logits
elif logprobs_mode == "processed_logprobs":
Expand Down
Loading