diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 77fa6402180e..88b1b0b8e8e9 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -25,6 +25,10 @@
 _SAMPLING_EPS = 1e-5
 _MAX_TEMP = 1e-2
 
+MAX_LOGPROB_TOKEN_IDS = 128
+"""Upper bound on `SamplingParams.logprob_token_ids` list length. Must match
+the per-request row width allocated by the sampler's `LogprobTokenIdsState`."""
+
 
 class SamplingType(IntEnum):
     GREEDY = 0
@@ -628,6 +632,16 @@ def bad_words_token_ids(self) -> list[list[int]] | None:
         # For internal use only. Backward compatibility not guaranteed
         return self._bad_words_token_ids
 
+    @property
+    def num_logprobs(self) -> int | None:
+        """Number of sample logprobs to return per output token, or `None` if
+        no sample logprobs were requested. Takes `logprob_token_ids` into
+        account: when `logprobs` is unset but `logprob_token_ids` is set,
+        returns `len(logprob_token_ids)`."""
+        if self.logprobs is not None:
+            return self.logprobs
+        return len(self.logprob_token_ids) if self.logprob_token_ids else None
+
     def clone(self) -> "SamplingParams":
         """If skip_clone is True, uses shallow copy instead of deep copy."""
         if self.skip_clone:
@@ -666,6 +680,17 @@ def _validate_logprobs(self, model_config: ModelConfig) -> None:
                 value=num_logprobs,
             )
 
+        # Validate logprob_token_ids.
+        if self.logprob_token_ids is not None:
+            n = len(self.logprob_token_ids)
+            if n > MAX_LOGPROB_TOKEN_IDS:
+                raise VLLMValidationError(
+                    f"Requested logprob_token_ids of length {n}, "
+                    f"which is greater than max allowed: {MAX_LOGPROB_TOKEN_IDS}",
+                    parameter="logprob_token_ids",
+                    value=n,
+                )
+
         # Validate prompt logprobs.
         if num_prompt_logprobs := self.prompt_logprobs:
             if num_prompt_logprobs == -1:
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 395fa80bfe53..0b97e46db95f 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1448,7 +1448,7 @@ def update_from_output(
             # Extract sample logprobs if needed.
             if (
                 request.sampling_params is not None
-                and request.sampling_params.logprobs is not None
+                and request.sampling_params.num_logprobs is not None
                 and logprobs
             ):
                 new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py
index 9ada6eda48ce..74a45ab1e4d4 100644
--- a/vllm/v1/engine/logprobs.py
+++ b/vllm/v1/engine/logprobs.py
@@ -47,7 +47,7 @@ def from_new_request(
     ) -> "LogprobsProcessor":
         sampling_params = request.sampling_params
         assert sampling_params is not None
-        num_logprobs = sampling_params.logprobs
+        num_logprobs = sampling_params.num_logprobs
        num_prompt_logprobs = sampling_params.prompt_logprobs
         return cls(
             tokenizer=tokenizer,
diff --git a/vllm/v1/worker/gpu/sample/logprob.py b/vllm/v1/worker/gpu/sample/logprob.py
index 4317cad9ce7f..7530337fcd12 100644
--- a/vllm/v1/worker/gpu/sample/logprob.py
+++ b/vllm/v1/worker/gpu/sample/logprob.py
@@ -1,10 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import numpy as np
 import torch
 
+from vllm.sampling_params import MAX_LOGPROB_TOKEN_IDS, SamplingParams
 from vllm.triton_utils import tl, triton
 from vllm.v1.outputs import LogprobsTensors
+from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
 
 
 @triton.jit
@@ -75,6 +78,9 @@ def _ranks_kernel(
 def compute_token_logprobs(
     logits: torch.Tensor, token_ids: torch.Tensor
 ) -> torch.Tensor:
+    # NOTE(woosuk): To save GPU memory, we do not materialize the full
+    # [batch_size, vocab_size] logprobs tensor. The kernel computes
+    # max + logsumexp per row and only emits logprobs at `token_ids`.
     batch_size, vocab_size = logits.shape
     token_ids = token_ids.to(torch.int64)
     num_logprobs = token_ids.shape[1]
@@ -97,18 +103,52 @@ def compute_topk_logprobs(
     num_logprobs: int,
     sampled_token_ids: torch.Tensor,
     cu_num_logits: list[int] | None = None,
+    logprob_token_ids_state: "LogprobTokenIdsState | None" = None,
+    expanded_idx_mapping: torch.Tensor | None = None,
+    max_per_req_token_ids: int = 0,
 ) -> LogprobsTensors:
     assert num_logprobs >= 0
     batch_size, vocab_size = logits.shape
-    logprob_token_ids = sampled_token_ids.unsqueeze(-1)
-    if num_logprobs > 0:
-        topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
-        logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)
-
-    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
-    # logprobs tensor. Instead, we only compute and return the logprobs of
-    # the topk + 1 tokens.
-    logprobs = compute_token_logprobs(logits, logprob_token_ids)
+
+    if max_per_req_token_ids == 0:
+        # Fast path: no request asked for custom logprob_token_ids.
+        logprob_token_ids = sampled_token_ids.unsqueeze(-1)
+        if num_logprobs > 0:
+            topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
+            logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)
+        logprobs = compute_token_logprobs(logits, logprob_token_ids)
+    else:
+        # Some requests specified logprob_token_ids. Build the [batch_size,
+        # 1 + max_cols] token_ids matrix and validity mask on the GPU via a
+        # single triton kernel, overriding the topk columns with per-request
+        # tokens where applicable.
+        assert logprob_token_ids_state is not None
+        assert expanded_idx_mapping is not None
+        topk_indices = None
+        if num_logprobs > 0:
+            topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
+
+        num_cols = max(num_logprobs, max_per_req_token_ids)
+        logprob_token_ids = sampled_token_ids.new_zeros((batch_size, 1 + num_cols))
+        valid_mask = torch.zeros_like(logprob_token_ids, dtype=torch.bool)
+        _fill_logprob_token_ids_kernel[(batch_size,)](
+            logprob_token_ids,
+            logprob_token_ids.stride(0),
+            valid_mask,
+            valid_mask.stride(0),
+            sampled_token_ids,
+            topk_indices if topk_indices is not None else logprob_token_ids,
+            topk_indices.stride(0) if topk_indices is not None else 0,
+            expanded_idx_mapping,
+            logprob_token_ids_state.num_token_ids.gpu,
+            logprob_token_ids_state.token_ids.gpu,
+            logprob_token_ids_state.token_ids.gpu.stride(0),
+            NUM_TOPK=num_logprobs,
+            PADDED_COLS=triton.next_power_of_2(num_cols),
+        )
+        logprobs = compute_token_logprobs(logits, logprob_token_ids)
+        logprobs = logprobs.masked_fill(~valid_mask, float("-inf"))
+
     token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
     _ranks_kernel[(batch_size,)](
         token_ranks,
@@ -124,3 +164,87 @@ def compute_topk_logprobs(
         selected_token_ranks=token_ranks,
         cu_num_generated_tokens=cu_num_logits,
     )
+
+
+@triton.jit
+def _fill_logprob_token_ids_kernel(
+    # [batch_size, 1 + num_cols]
+    out_token_ids_ptr,
+    out_token_ids_stride,
+    # [batch_size, 1 + num_cols]
+    out_valid_mask_ptr,
+    out_valid_mask_stride,
+    sampled_token_ids_ptr,  # [batch_size]
+    topk_indices_ptr,  # [batch_size, NUM_TOPK] (unused when NUM_TOPK == 0)
+    topk_indices_stride,
+    expanded_idx_mapping_ptr,  # [batch_size] -> req_state_idx
+    num_per_req_token_ids_ptr,  # [max_num_reqs]
+    per_req_token_ids_ptr,  # [max_num_reqs, MAX_LOGPROB_TOKEN_IDS]
+    per_req_token_ids_stride,
+    NUM_TOPK: tl.constexpr,
+    PADDED_COLS: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+
+    # Column 0: always the sampled token, always valid.
+    sampled = tl.load(sampled_token_ids_ptr + batch_idx)
+    tl.store(out_token_ids_ptr + batch_idx * out_token_ids_stride, sampled)
+    tl.store(out_valid_mask_ptr + batch_idx * out_valid_mask_stride, 1)
+
+    req_state_idx = tl.load(expanded_idx_mapping_ptr + batch_idx)
+    num_custom = tl.load(num_per_req_token_ids_ptr + req_state_idx)
+
+    col = tl.arange(0, PADDED_COLS)
+    tid_base = out_token_ids_ptr + batch_idx * out_token_ids_stride + 1
+    mask_base = out_valid_mask_ptr + batch_idx * out_valid_mask_stride + 1
+
+    if num_custom > 0:
+        # Override topk with per-request custom tokens.
+        src = per_req_token_ids_ptr + req_state_idx * per_req_token_ids_stride
+        valid = col < num_custom
+        # per_req_token_ids is int32; output is int64.
+        tokens = tl.load(src + col, mask=valid, other=0).to(tl.int64)
+    else:
+        # Fill with topk indices (no-op when NUM_TOPK == 0).
+        src = topk_indices_ptr + batch_idx * topk_indices_stride
+        valid = col < NUM_TOPK
+        tokens = tl.load(src + col, mask=valid, other=0)
+
+    tl.store(tid_base + col, tokens, mask=valid)
+    tl.store(mask_base + col, tl.full([PADDED_COLS], 1, tl.int1), mask=valid)
+
+
+class LogprobTokenIdsState:
+    """Per-request override of which token ids' logprobs to return.
+
+    See `SamplingParams.logprob_token_ids`.
+ """ + + def __init__(self, max_num_reqs: int, device: torch.device): + self.max_num_reqs = max_num_reqs + self.num_token_ids = UvaBackedTensor(max_num_reqs, dtype=torch.int32) + self.token_ids = StagedWriteTensor( + (max_num_reqs, MAX_LOGPROB_TOKEN_IDS), + dtype=torch.int32, + device=device, + ) + + def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None: + token_ids = sampling_params.logprob_token_ids + if not token_ids: + self.num_token_ids.np[req_idx] = 0 + return + n = len(token_ids) + if n > MAX_LOGPROB_TOKEN_IDS: + raise ValueError( + f"Too many logprob_token_ids: {n}. The max is {MAX_LOGPROB_TOKEN_IDS}." + ) + self.num_token_ids.np[req_idx] = n + self.token_ids.stage_write(req_idx, 0, token_ids) + + def apply_staged_writes(self) -> None: + self.num_token_ids.copy_to_uva() + self.token_ids.apply_write() + + def max_num_token_ids(self, idx_mapping_np: np.ndarray) -> int: + return int(self.num_token_ids.np[idx_mapping_np].max(initial=0)) diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index 6f73ca87ac67..5d91d5b2f097 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -12,7 +12,10 @@ from vllm.v1.worker.gpu.sample.bad_words import BadWordsState from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState -from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs +from vllm.v1.worker.gpu.sample.logprob import ( + LogprobTokenIdsState, + compute_topk_logprobs, +) from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.penalties import PenaltiesState from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates @@ -38,6 +41,7 @@ def __init__( self.penalties_state = PenaltiesState(req_states) self.logit_bias_state = LogitBiasState(max_num_reqs, device) self.bad_words_state = BadWordsState(req_states) + self.logprob_token_ids_state = LogprobTokenIdsState(max_num_reqs, device) self.num_speculative_tokens = num_speculative_tokens def add_request( @@ -47,12 +51,14 @@ def add_request( self.penalties_state.add_request(req_idx, sampling_params) self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params) self.bad_words_state.add_request(req_idx, sampling_params) + self.logprob_token_ids_state.add_request(req_idx, sampling_params) def apply_staged_writes(self) -> None: self.sampling_states.apply_staged_writes() self.penalties_state.apply_staged_writes() self.logit_bias_state.apply_staged_writes() self.bad_words_state.apply_staged_writes() + self.logprob_token_ids_state.apply_staged_writes() def __call__( self, @@ -79,13 +85,23 @@ def __call__( ) max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np) - if max_num_logprobs != NO_LOGPROBS: + max_per_req_token_ids = self.logprob_token_ids_state.max_num_token_ids( + idx_mapping_np + ) + if max_num_logprobs != NO_LOGPROBS or max_per_req_token_ids > 0: if self.logprobs_mode == "processed_logprobs": logits = processed_logits expanded_logits = logits.shape[0] != idx_mapping_np.shape[0] cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None + num_logprobs = max_num_logprobs if max_num_logprobs != NO_LOGPROBS else 0 logprobs_tensors = compute_topk_logprobs( - logits, max_num_logprobs, sampled, cu_num_logits + logits, + num_logprobs, + sampled, + cu_num_logits, + logprob_token_ids_state=self.logprob_token_ids_state, + expanded_idx_mapping=input_batch.expanded_idx_mapping, + 
+                max_per_req_token_ids=max_per_req_token_ids,
             )
         else:
             logprobs_tensors = None
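
The diff references `logprob_token_ids` without declaring the field itself, but the new `num_logprobs` property pins down how it interacts with `logprobs`. A minimal sketch of those semantics, assuming `logprob_token_ids` is an ordinary constructor-settable `SamplingParams` field (an assumption; its declaration is outside this diff):

    from vllm import SamplingParams

    # Only custom token ids requested: num_logprobs falls back to their
    # count, so downstream consumers still see a non-None width.
    p = SamplingParams(logprob_token_ids=[5, 9, 12])
    assert p.logprobs is None
    assert p.num_logprobs == 3

    # An explicit `logprobs` takes precedence over the fallback.
    q = SamplingParams(logprobs=4, logprob_token_ids=[5, 9, 12])
    assert q.num_logprobs == 4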
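
For intuition about the triton fill kernel, here is a rough pure-PyTorch equivalent of the `[batch_size, 1 + num_cols]` token-id matrix and validity mask it produces. This is an exposition-only sketch (the function `fill_logprob_token_ids_reference` is hypothetical, not part of the patch); shapes mirror the kernel arguments:

    import torch

    def fill_logprob_token_ids_reference(
        sampled_token_ids: torch.Tensor,      # [batch_size]
        topk_indices: torch.Tensor | None,    # [batch_size, num_topk] or None
        expanded_idx_mapping: torch.Tensor,   # [batch_size] -> req_state_idx
        num_per_req_token_ids: torch.Tensor,  # [max_num_reqs]
        per_req_token_ids: torch.Tensor,      # [max_num_reqs, MAX_LOGPROB_TOKEN_IDS]
        num_cols: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size = sampled_token_ids.shape[0]
        token_ids = sampled_token_ids.new_zeros((batch_size, 1 + num_cols))
        valid = torch.zeros_like(token_ids, dtype=torch.bool)
        # Column 0 always holds the sampled token and is always valid.
        token_ids[:, 0] = sampled_token_ids
        valid[:, 0] = True
        for i in range(batch_size):
            req = int(expanded_idx_mapping[i])
            n = int(num_per_req_token_ids[req])
            if n > 0:
                # Per-request custom tokens override the top-k columns.
                token_ids[i, 1 : 1 + n] = per_req_token_ids[req, :n].to(token_ids.dtype)
                valid[i, 1 : 1 + n] = True
            elif topk_indices is not None:
                # No custom tokens: fall back to the top-k indices.
                k = topk_indices.shape[1]
                token_ids[i, 1 : 1 + k] = topk_indices[i]
                valid[i, 1 : 1 + k] = True
        return token_ids, valid

Columns left invalid are exactly the positions `compute_topk_logprobs` masks to `-inf` after calling `compute_token_logprobs`.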
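
Finally, the per-step lifecycle of `LogprobTokenIdsState`, reduced to the calls visible in this diff; a sketch assuming a CUDA device and the patch applied, with illustrative request indices and index-mapping arrays:

    import numpy as np
    import torch

    from vllm.sampling_params import SamplingParams
    from vllm.v1.worker.gpu.sample.logprob import LogprobTokenIdsState

    state = LogprobTokenIdsState(max_num_reqs=4, device=torch.device("cuda"))

    # On arrival: request 0 wants two custom token ids, request 1 wants none.
    # Writes are staged CPU-side, mirroring the other sampler states.
    state.add_request(0, SamplingParams(logprob_token_ids=[42, 1337]))
    state.add_request(1, SamplingParams())

    # Once per step, before sampling: flush staged writes to the device buffers.
    state.apply_staged_writes()

    # Cheap CPU-side gate in Sampler.__call__: 0 keeps compute_topk_logprobs
    # on the topk-only fast path, > 0 selects the triton fill-kernel path.
    assert state.max_num_token_ids(np.array([0, 1])) == 2
    assert state.max_num_token_ids(np.array([1])) == 0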