From fa3b01ed292b1669591c40649e8f17854fb7b370 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 19 May 2026 11:01:39 -0700 Subject: [PATCH 1/2] [BugFix] Use correct logprobs for logprob_token_ids Signed-off-by: Nick Hill --- vllm/v1/sample/sampler.py | 43 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index d4ec2a2ddcab..d34f0d0f4bb9 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -13,7 +13,6 @@ from vllm.v1.sample.ops.logprobs import batched_count_greater_than from vllm.v1.sample.ops.penalties import apply_all_penalties from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler -from vllm.v1.worker.gpu.sample.logprob import compute_token_logprobs _SAMPLING_EPS = 1e-5 @@ -78,7 +77,8 @@ def forward( # This is different from the V0 sampler, which uses the logits that # is used for sampling (after penalties and temperature scaling). num_logprobs = sampling_metadata.max_num_logprobs - if num_logprobs is not None: + raw_logprobs: torch.Tensor | None = None + if num_logprobs is not None or sampling_metadata.logprob_token_ids: if logprobs_mode == "raw_logprobs": raw_logprobs = self.compute_logprobs(logits) elif logprobs_mode == "raw_logits": @@ -107,8 +107,9 @@ def forward( # This is used by generative_scoring API to get logprobs for specific tokens logprob_token_ids_tensors = None if sampling_metadata.logprob_token_ids: + assert raw_logprobs is not None logprob_token_ids_tensors = self.gather_specific_token_logprobs( - logits, sampling_metadata.logprob_token_ids, sampled + raw_logprobs, sampling_metadata.logprob_token_ids, sampled ) if num_logprobs is None: @@ -144,22 +145,19 @@ def forward( def gather_specific_token_logprobs( self, - logits: torch.Tensor, + logprobs: torch.Tensor, logprob_token_ids: dict[int, list[int]], sampled: torch.Tensor, ) -> LogprobsTensors | None: - """Compute logprobs for specific token IDs using Triton kernel. - - This method handles heterogeneous token ID lists across requests by - padding shorter lists to max length and using a fused Triton kernel - for efficient log_softmax + gather computation. + """Gather logprobs for specific token IDs requested per request. - Benchmarks show the Triton kernel approach is ~1.4x faster than sparse - gather for batch sizes > 1 due to the fused kernel reducing memory - bandwidth requirements. + Used by the generative_scoring API to return logprobs for an explicit + set of token ids rather than the top-k. Handles heterogeneous token + id lists across requests by padding shorter lists to the max length. Args: - logits: [batch_size, vocab_size] tensor of logits + logprobs: [batch_size, vocab_size] tensor of (raw) logprobs to + gather from. logprob_token_ids: dict mapping req_index -> list of token IDs sampled: [batch_size] tensor of sampled token IDs @@ -170,8 +168,8 @@ def gather_specific_token_logprobs( if not logprob_token_ids: return None - batch_size = logits.shape[0] - device = logits.device + batch_size = logprobs.shape[0] + device = logprobs.device # Find max number of tokens across all requests max_num_tokens = max(len(tids) for tids in logprob_token_ids.values()) @@ -200,19 +198,20 @@ def gather_specific_token_logprobs( # tensor so we don't need to D2H + re-upload. token_ids_tensor[:, 0] = sampled - # Compute logprobs using the fused Triton kernel (log_softmax + gather) - logprobs = compute_token_logprobs(logits, token_ids_tensor) + # Gather logprobs at the requested token ids. + gathered_logprobs = logprobs.gather(-1, token_ids_tensor) # Mask invalid (padded) positions with -inf - logprobs = logprobs.masked_fill(~valid_mask, float("-inf")) + gathered_logprobs = gathered_logprobs.masked_fill(~valid_mask, float("-inf")) - # Compute ranks for the sampled token - sampled_logits = logits.gather(-1, sampled.unsqueeze(-1)) - token_ranks = (logits > sampled_logits).sum(dim=-1) + # Compute ranks for the sampled token. log_softmax is monotonic w.r.t. + # the original logits, so ranks computed from logprobs are equivalent. + sampled_logprobs = logprobs.gather(-1, sampled.unsqueeze(-1)) + token_ranks = batched_count_greater_than(logprobs, sampled_logprobs) return LogprobsTensors( logprob_token_ids=token_ids_tensor.to(torch.int32), - logprobs=logprobs, + logprobs=gathered_logprobs, selected_token_ranks=token_ranks, ) From 83e4ebe85e33359dc788850aea4e6abff1ea7e21 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 19 May 2026 11:25:54 -0700 Subject: [PATCH 2/2] address gemini comments Signed-off-by: Nick Hill --- vllm/v1/sample/sampler.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index d34f0d0f4bb9..9ac3821a3261 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -207,6 +207,10 @@ def gather_specific_token_logprobs( # Compute ranks for the sampled token. log_softmax is monotonic w.r.t. # the original logits, so ranks computed from logprobs are equivalent. sampled_logprobs = logprobs.gather(-1, sampled.unsqueeze(-1)) + # Avoid 0/1 specialization recompile on the batch dimension of the + # compiled batched_count_greater_than. See gather_logprobs for context. + torch._dynamo.decorators.mark_unbacked(logprobs, 0) + torch._dynamo.decorators.mark_unbacked(sampled_logprobs, 0) token_ranks = batched_count_greater_than(logprobs, sampled_logprobs) return LogprobsTensors( @@ -251,7 +255,10 @@ def sample( greedy_sampled = self.greedy_sample(logits) if sampling_metadata.all_greedy: processed_logprobs = None - if sampling_metadata.max_num_logprobs is not None: + if ( + sampling_metadata.max_num_logprobs is not None + or sampling_metadata.logprob_token_ids + ): if logprobs_mode == "processed_logits": processed_logprobs = logits elif logprobs_mode == "processed_logprobs":