Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 0 additions & 66 deletions tests/v1/engine/test_logprobs_processor.py

This file was deleted.

8 changes: 0 additions & 8 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,14 +690,6 @@ def _validate_logprobs(self, model_config: ModelConfig) -> None:
parameter="logprob_token_ids",
value=n,
)
if self.logprobs is not None and self.logprobs != n:
raise VLLMValidationError(
f"When both logprobs and logprob_token_ids are set, "
f"logprobs must equal len(logprob_token_ids). Got "
f"logprobs={self.logprobs}, len(logprob_token_ids)={n}.",
parameter="logprob_token_ids",
value=n,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the validation check between logprobs and logprob_token_ids re-introduces a state where inconsistent parameters are accepted but not correctly handled by the sampler. Since the V1 sampler uses an if/else logic that prioritizes custom tokens, providing both with different lengths will lead to unexpected output formats or partially filled rows. This validation should be restored to maintain API integrity.

# Validate prompt logprobs.
if num_prompt_logprobs := self.prompt_logprobs:
Expand Down
16 changes: 7 additions & 9 deletions vllm/v1/worker/gpu/sample/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,9 @@ def compute_topk_logprobs(
# tokens where applicable.
assert logprob_token_ids_state is not None
assert expanded_idx_mapping is not None

topk_indices = None
if num_logprobs > 0:
topk_token_ids = torch.topk(logits, num_logprobs, dim=-1).indices
topk_token_ids = topk_token_ids.to(torch.int32)
else:
# This tensor just used as an int32 pointer, data not accessed.
topk_token_ids = logprob_token_ids_state.token_ids.gpu
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Reverting the cast to int32 for topk_indices re-introduces a dtype inconsistency. In the V1 engine, token IDs are generally expected to be int32 (as seen in vllm/v1/outputs.py). By keeping topk_indices as int64 (the default for torch.topk), the resulting logprob_token_ids tensor will also be int64. This can cause issues downstream in components that expect 32-bit integers. It is safer to cast the indices to match the dtype of sampled_token_ids.

Suggested change
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices.to(sampled_token_ids.dtype)


num_cols = max(num_logprobs, max_per_req_token_ids)
logprob_token_ids = sampled_token_ids.new_zeros((batch_size, 1 + num_cols))
Expand All @@ -141,8 +137,8 @@ def compute_topk_logprobs(
valid_mask,
valid_mask.stride(0),
sampled_token_ids,
topk_token_ids,
topk_token_ids.stride(0),
topk_indices if topk_indices is not None else logprob_token_ids,
topk_indices.stride(0) if topk_indices is not None else 0,
expanded_idx_mapping,
logprob_token_ids_state.num_token_ids.gpu,
logprob_token_ids_state.token_ids.gpu,
Expand Down Expand Up @@ -206,12 +202,14 @@ def _fill_logprob_token_ids_kernel(
# Override topk with per-request custom tokens.
src = per_req_token_ids_ptr + req_state_idx * per_req_token_ids_stride
valid = col < num_custom
# per_req_token_ids is int32; output is int64.
tokens = tl.load(src + col, mask=valid, other=0).to(tl.int64)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The explicit cast to tl.int64 before storing into out_token_ids_ptr is a potential source of memory corruption. In the V1 engine, sampled_token_ids (and consequently logprob_token_ids) are typically torch.int32. Storing a 64-bit integer into a 32-bit pointer in Triton will overwrite adjacent memory locations. Since compute_token_logprobs already handles the necessary cast to int64 for indexing at line 85, the kernel should store tokens using the native dtype of the output tensor to ensure memory safety. Note that the comment at line 205 should also be updated to reflect this change.

Suggested change
tokens = tl.load(src + col, mask=valid, other=0).to(tl.int64)
tokens = tl.load(src + col, mask=valid, other=0)

else:
# Fill with topk indices (no-op when NUM_TOPK == 0).
src = topk_indices_ptr + batch_idx * topk_indices_stride
valid = col < NUM_TOPK
tokens = tl.load(src + col, mask=valid, other=0)

tokens = tl.load(src + col, mask=valid, other=0).to(tl.int64)
tl.store(tid_base + col, tokens, mask=valid)
tl.store(mask_base + col, tl.full([PADDED_COLS], 1, tl.int1), mask=valid)

Expand Down
Loading