diff --git a/tests/v1/engine/test_logprobs_processor.py b/tests/v1/engine/test_logprobs_processor.py deleted file mode 100644 index edb8cef518ca..000000000000 --- a/tests/v1/engine/test_logprobs_processor.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Unit tests for LogprobsProcessor. - -These tests exercise the truncation invariant that the MRV2 sampler relies -on: when the sampler returns a row wider than a request's own -`num_logprobs + 1` (because another request in the batch needed a wider -row), the trailing positions are populated with sentinel values -(`token_id=0`, `logprob=-inf`). LogprobsProcessor must read only the first -`num_logprobs + 1` entries so those sentinels never reach the user. -""" - -import numpy as np - -from vllm.logprobs import create_sample_logprobs -from vllm.v1.engine.logprobs import LogprobsProcessor -from vllm.v1.outputs import LogprobsLists - - -def _make_processor(num_logprobs: int) -> LogprobsProcessor: - return LogprobsProcessor( - tokenizer=None, - logprobs=create_sample_logprobs(flat_logprobs=False), - prompt_logprobs=None, - cumulative_logprob=0.0, - num_logprobs=num_logprobs, - num_prompt_logprobs=None, - ) - - -def test_drops_trailing_sentinel_columns(): - """A request that asked for 3 custom token logprobs but ended up in a - batch padded to width 5 must not surface the trailing -inf entries.""" - processor = _make_processor(num_logprobs=3) - - sampled = 42 - # Layout: [sampled, custom_1, custom_2, custom_3, SENTINEL, SENTINEL] - # Use float32-exact values so cumulative_logprob compares cleanly. - token_ids = np.array([[sampled, 100, 200, 300, 0, 0]], dtype=np.int32) - logprobs = np.array([[-0.5, -1.0, -2.0, -3.0, -np.inf, -np.inf]], dtype=np.float32) - ranks = np.array([1], dtype=np.int32) - - processor._update_sample_logprobs(LogprobsLists(token_ids, logprobs, ranks)) - - assert len(processor.logprobs) == 1 - pos = processor.logprobs[0] - # Exactly sampled + 3 requested tokens; trailing sentinels dropped. - assert set(pos.keys()) == {sampled, 100, 200, 300} - assert 0 not in pos - assert all(np.isfinite(lp.logprob) for lp in pos.values()) - # cumulative_logprob comes from the sampled token's logprob only. - assert processor.cumulative_logprob == -0.5 - - -def test_accepts_exactly_sized_row(): - """When the row is exactly num_logprobs+1, no truncation needed.""" - processor = _make_processor(num_logprobs=2) - - token_ids = np.array([[7, 11, 13]], dtype=np.int32) - logprobs = np.array([[-0.5, -1.5, -2.5]], dtype=np.float32) - ranks = np.array([1], dtype=np.int32) - - processor._update_sample_logprobs(LogprobsLists(token_ids, logprobs, ranks)) - - pos = processor.logprobs[0] - assert set(pos.keys()) == {7, 11, 13} diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 76834e9bd779..88b1b0b8e8e9 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -690,14 +690,6 @@ def _validate_logprobs(self, model_config: ModelConfig) -> None: parameter="logprob_token_ids", value=n, ) - if self.logprobs is not None and self.logprobs != n: - raise VLLMValidationError( - f"When both logprobs and logprob_token_ids are set, " - f"logprobs must equal len(logprob_token_ids). Got " - f"logprobs={self.logprobs}, len(logprob_token_ids)={n}.", - parameter="logprob_token_ids", - value=n, - ) # Validate prompt logprobs. if num_prompt_logprobs := self.prompt_logprobs: diff --git a/vllm/v1/worker/gpu/sample/logprob.py b/vllm/v1/worker/gpu/sample/logprob.py index cf24c186e93a..7530337fcd12 100644 --- a/vllm/v1/worker/gpu/sample/logprob.py +++ b/vllm/v1/worker/gpu/sample/logprob.py @@ -124,13 +124,9 @@ def compute_topk_logprobs( # tokens where applicable. assert logprob_token_ids_state is not None assert expanded_idx_mapping is not None - + topk_indices = None if num_logprobs > 0: - topk_token_ids = torch.topk(logits, num_logprobs, dim=-1).indices - topk_token_ids = topk_token_ids.to(torch.int32) - else: - # This tensor just used as an int32 pointer, data not accessed. - topk_token_ids = logprob_token_ids_state.token_ids.gpu + topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices num_cols = max(num_logprobs, max_per_req_token_ids) logprob_token_ids = sampled_token_ids.new_zeros((batch_size, 1 + num_cols)) @@ -141,8 +137,8 @@ def compute_topk_logprobs( valid_mask, valid_mask.stride(0), sampled_token_ids, - topk_token_ids, - topk_token_ids.stride(0), + topk_indices if topk_indices is not None else logprob_token_ids, + topk_indices.stride(0) if topk_indices is not None else 0, expanded_idx_mapping, logprob_token_ids_state.num_token_ids.gpu, logprob_token_ids_state.token_ids.gpu, @@ -206,12 +202,14 @@ def _fill_logprob_token_ids_kernel( # Override topk with per-request custom tokens. src = per_req_token_ids_ptr + req_state_idx * per_req_token_ids_stride valid = col < num_custom + # per_req_token_ids is int32; output is int64. + tokens = tl.load(src + col, mask=valid, other=0).to(tl.int64) else: # Fill with topk indices (no-op when NUM_TOPK == 0). src = topk_indices_ptr + batch_idx * topk_indices_stride valid = col < NUM_TOPK + tokens = tl.load(src + col, mask=valid, other=0) - tokens = tl.load(src + col, mask=valid, other=0).to(tl.int64) tl.store(tid_base + col, tokens, mask=valid) tl.store(mask_base + col, tl.full([PADDED_COLS], 1, tl.int1), mask=valid)