25 changes: 25 additions & 0 deletions vllm/sampling_params.py
@@ -25,6 +25,10 @@
_SAMPLING_EPS = 1e-5
_MAX_TEMP = 1e-2

MAX_LOGPROB_TOKEN_IDS = 128
"""Upper bound on `SamplingParams.logprob_token_ids` list length. Must match
the per-request row width allocated by the sampler's `LogprobTokenIdsState`."""


class SamplingType(IntEnum):
GREEDY = 0
@@ -628,6 +632,16 @@ def bad_words_token_ids(self) -> list[list[int]] | None:
# For internal use only. Backward compatibility not guaranteed
return self._bad_words_token_ids

@property
def num_logprobs(self) -> int | None:
"""Number of sample logprobs to return per output token, or `None` if
no sample logprobs were requested. Takes `logprob_token_ids` into
account: when `logprobs` is unset but `logprob_token_ids` is set,
returns `len(logprob_token_ids)`."""
if self.logprobs is not None:
return self.logprobs
return len(self.logprob_token_ids) if self.logprob_token_ids else None

def clone(self) -> "SamplingParams":
"""If skip_clone is True, uses shallow copy instead of deep copy."""
if self.skip_clone:
@@ -666,6 +680,17 @@ def _validate_logprobs(self, model_config: ModelConfig) -> None:
value=num_logprobs,
)

# Validate logprob_token_ids.
if self.logprob_token_ids is not None:
n = len(self.logprob_token_ids)
if n > MAX_LOGPROB_TOKEN_IDS:
raise VLLMValidationError(
f"Requested logprob_token_ids of length {n}, "
f"which is greater than max allowed: {MAX_LOGPROB_TOKEN_IDS}",
parameter="logprob_token_ids",
value=n,
)

# Validate prompt logprobs.
if num_prompt_logprobs := self.prompt_logprobs:
if num_prompt_logprobs == -1:
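A minimal usage sketch of the new property and validation (not part of the diff; it assumes `logprob_token_ids` is an ordinary `SamplingParams` field and uses the top-level `vllm` export):

from vllm import SamplingParams

# `logprobs` set explicitly: `num_logprobs` just returns it.
params = SamplingParams(logprobs=5)
assert params.num_logprobs == 5

# Only `logprob_token_ids` set: `num_logprobs` falls back to its length,
# so downstream code still sees a sample-logprob request.
params = SamplingParams(logprob_token_ids=[11, 42, 7])
assert params.logprobs is None
assert params.num_logprobs == 3

# Neither set: no sample logprobs requested.
assert SamplingParams().num_logprobs is None

# Lists longer than MAX_LOGPROB_TOKEN_IDS (128) are rejected during
# request validation with a VLLMValidationError.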
2 changes: 1 addition & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -1448,7 +1448,7 @@ def update_from_output(
# Extract sample logprobs if needed.
if (
request.sampling_params is not None
and request.sampling_params.logprobs is not None
and request.sampling_params.num_logprobs is not None
and logprobs
):
new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
2 changes: 1 addition & 1 deletion vllm/v1/engine/logprobs.py
@@ -47,7 +47,7 @@ def from_new_request(
) -> "LogprobsProcessor":
sampling_params = request.sampling_params
assert sampling_params is not None
num_logprobs = sampling_params.logprobs
num_logprobs = sampling_params.num_logprobs
num_prompt_logprobs = sampling_params.prompt_logprobs
return cls(
tokenizer=tokenizer,
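Both one-line changes above switch the gate from `logprobs` to `num_logprobs`, so a request that only sets `logprob_token_ids` is still routed through sample-logprob extraction in the scheduler and through the `LogprobsProcessor`. A brief illustration (hypothetical token ids, not from the diff):

from vllm import SamplingParams

params = SamplingParams(logprob_token_ids=[1000, 2000])
# `logprobs` is unset, but `num_logprobs` resolves to 2, so
# update_from_output() slices this request's logprobs and
# LogprobsProcessor.from_new_request() is built with num_logprobs=2.
assert params.logprobs is None
assert params.num_logprobs == 2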
142 changes: 133 additions & 9 deletions vllm/v1/worker/gpu/sample/logprob.py
@@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from vllm.sampling_params import MAX_LOGPROB_TOKEN_IDS, SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor


@triton.jit
@@ -75,6 +78,9 @@ def _ranks_kernel(
def compute_token_logprobs(
logits: torch.Tensor, token_ids: torch.Tensor
) -> torch.Tensor:
# NOTE(woosuk): To save GPU memory, we do not materialize the full
# [batch_size, vocab_size] logprobs tensor. The kernel computes
# max + logsumexp per row and only emits logprobs at `token_ids`.
batch_size, vocab_size = logits.shape
token_ids = token_ids.to(torch.int64)
num_logprobs = token_ids.shape[1]
@@ -97,18 +103,52 @@ def compute_topk_logprobs(
num_logprobs: int,
sampled_token_ids: torch.Tensor,
cu_num_logits: list[int] | None = None,
logprob_token_ids_state: "LogprobTokenIdsState | None" = None,
expanded_idx_mapping: torch.Tensor | None = None,
max_per_req_token_ids: int = 0,
) -> LogprobsTensors:
assert num_logprobs >= 0
batch_size, vocab_size = logits.shape
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
if num_logprobs > 0:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)

# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
logprobs = compute_token_logprobs(logits, logprob_token_ids)

if max_per_req_token_ids == 0:
# Fast path: no request asked for custom logprob_token_ids.
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
if num_logprobs > 0:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)
logprobs = compute_token_logprobs(logits, logprob_token_ids)
else:
# Some requests specified logprob_token_ids. Build the [batch_size,
# 1 + max_cols] token_ids matrix and validity mask on the GPU via a
# single triton kernel, overriding the topk columns with per-request
# tokens where applicable.
assert logprob_token_ids_state is not None
assert expanded_idx_mapping is not None
topk_indices = None
if num_logprobs > 0:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices

num_cols = max(num_logprobs, max_per_req_token_ids)
logprob_token_ids = sampled_token_ids.new_zeros((batch_size, 1 + num_cols))
valid_mask = torch.zeros_like(logprob_token_ids, dtype=torch.bool)
_fill_logprob_token_ids_kernel[(batch_size,)](
logprob_token_ids,
logprob_token_ids.stride(0),
valid_mask,
valid_mask.stride(0),
sampled_token_ids,
topk_indices if topk_indices is not None else logprob_token_ids,
topk_indices.stride(0) if topk_indices is not None else 0,
expanded_idx_mapping,
logprob_token_ids_state.num_token_ids.gpu,
logprob_token_ids_state.token_ids.gpu,
logprob_token_ids_state.token_ids.gpu.stride(0),
NUM_TOPK=num_logprobs,
PADDED_COLS=triton.next_power_of_2(num_cols),
)
logprobs = compute_token_logprobs(logits, logprob_token_ids)
logprobs = logprobs.masked_fill(~valid_mask, float("-inf"))

token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
_ranks_kernel[(batch_size,)](
token_ranks,
@@ -124,3 +164,87 @@ def compute_topk_logprobs(
selected_token_ranks=token_ranks,
cu_num_generated_tokens=cu_num_logits,
)


@triton.jit
def _fill_logprob_token_ids_kernel(
# [batch_size, 1 + num_cols]
out_token_ids_ptr,
out_token_ids_stride,
# [batch_size, 1 + num_cols]
out_valid_mask_ptr,
out_valid_mask_stride,
sampled_token_ids_ptr, # [batch_size]
topk_indices_ptr, # [batch_size, NUM_TOPK] (unused when NUM_TOPK == 0)
topk_indices_stride,
expanded_idx_mapping_ptr, # [batch_size] -> req_state_idx
num_per_req_token_ids_ptr, # [max_num_reqs]
per_req_token_ids_ptr, # [max_num_reqs, MAX_LOGPROB_TOKEN_IDS]
per_req_token_ids_stride,
NUM_TOPK: tl.constexpr,
PADDED_COLS: tl.constexpr,
):
batch_idx = tl.program_id(0)

# Column 0: always the sampled token, always valid.
sampled = tl.load(sampled_token_ids_ptr + batch_idx)
tl.store(out_token_ids_ptr + batch_idx * out_token_ids_stride, sampled)
tl.store(out_valid_mask_ptr + batch_idx * out_valid_mask_stride, 1)

req_state_idx = tl.load(expanded_idx_mapping_ptr + batch_idx)
num_custom = tl.load(num_per_req_token_ids_ptr + req_state_idx)

col = tl.arange(0, PADDED_COLS)
tid_base = out_token_ids_ptr + batch_idx * out_token_ids_stride + 1
mask_base = out_valid_mask_ptr + batch_idx * out_valid_mask_stride + 1

if num_custom > 0:
# Override topk with per-request custom tokens.
src = per_req_token_ids_ptr + req_state_idx * per_req_token_ids_stride
valid = col < num_custom
# per_req_token_ids is int32; output is int64.
tokens = tl.load(src + col, mask=valid, other=0).to(tl.int64)
else:
# Fill with topk indices (no-op when NUM_TOPK == 0).
src = topk_indices_ptr + batch_idx * topk_indices_stride
valid = col < NUM_TOPK
tokens = tl.load(src + col, mask=valid, other=0)

tl.store(tid_base + col, tokens, mask=valid)
tl.store(mask_base + col, tl.full([PADDED_COLS], 1, tl.int1), mask=valid)


class LogprobTokenIdsState:
"""Per-request override of which token ids' logprobs to return.

See `SamplingParams.logprob_token_ids`.
"""

def __init__(self, max_num_reqs: int, device: torch.device):
self.max_num_reqs = max_num_reqs
self.num_token_ids = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
self.token_ids = StagedWriteTensor(
(max_num_reqs, MAX_LOGPROB_TOKEN_IDS),
dtype=torch.int32,
device=device,
)

def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
token_ids = sampling_params.logprob_token_ids
if not token_ids:
self.num_token_ids.np[req_idx] = 0
return
n = len(token_ids)
if n > MAX_LOGPROB_TOKEN_IDS:
raise ValueError(
f"Too many logprob_token_ids: {n}. The max is {MAX_LOGPROB_TOKEN_IDS}."
)
self.num_token_ids.np[req_idx] = n
self.token_ids.stage_write(req_idx, 0, token_ids)

def apply_staged_writes(self) -> None:
self.num_token_ids.copy_to_uva()
self.token_ids.apply_write()

def max_num_token_ids(self, idx_mapping_np: np.ndarray) -> int:
return int(self.num_token_ids.np[idx_mapping_np].max(initial=0))
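For the semantics without reading the Triton, here is a rough NumPy equivalent of what `_fill_logprob_token_ids_kernel` writes per batch row (illustrative only: it skips the UVA/staged-write plumbing and the `expanded_idx_mapping` indirection used when logits rows are expanded for speculative decoding, indexing the custom ids directly by batch row):

import numpy as np

def fill_logprob_token_ids_reference(
    sampled_token_ids: np.ndarray,      # [batch_size], int64
    topk_indices: np.ndarray | None,    # [batch_size, num_topk] or None
    custom_token_ids: list[list[int]],  # per-row logprob_token_ids ([] if none)
    num_cols: int,                      # max(num_logprobs, max_per_req_token_ids)
) -> tuple[np.ndarray, np.ndarray]:
    batch_size = sampled_token_ids.shape[0]
    token_ids = np.zeros((batch_size, 1 + num_cols), dtype=np.int64)
    valid = np.zeros((batch_size, 1 + num_cols), dtype=bool)
    for i in range(batch_size):
        # Column 0: always the sampled token, always valid.
        token_ids[i, 0] = sampled_token_ids[i]
        valid[i, 0] = True
        custom = custom_token_ids[i]
        if custom:
            # Per-request override: custom ids take the place of the top-k columns.
            token_ids[i, 1 : 1 + len(custom)] = custom
            valid[i, 1 : 1 + len(custom)] = True
        elif topk_indices is not None:
            # No custom ids for this row: fill with the batch-wide top-k indices.
            k = topk_indices.shape[1]
            token_ids[i, 1 : 1 + k] = topk_indices[i]
            valid[i, 1 : 1 + k] = True
    return token_ids, valid

`compute_topk_logprobs` then evaluates logprobs only at these positions and masks the invalid padding columns to -inf, so rows whose custom list is shorter than `num_cols` never expose garbage values.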
22 changes: 19 additions & 3 deletions vllm/v1/worker/gpu/sample/sampler.py
@@ -12,7 +12,10 @@
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.logprob import (
LogprobTokenIdsState,
compute_topk_logprobs,
)
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
@@ -38,6 +41,7 @@ def __init__(
self.penalties_state = PenaltiesState(req_states)
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
self.bad_words_state = BadWordsState(req_states)
self.logprob_token_ids_state = LogprobTokenIdsState(max_num_reqs, device)
self.num_speculative_tokens = num_speculative_tokens

def add_request(
@@ -47,12 +51,14 @@ def add_request(
self.penalties_state.add_request(req_idx, sampling_params)
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
self.bad_words_state.add_request(req_idx, sampling_params)
self.logprob_token_ids_state.add_request(req_idx, sampling_params)

def apply_staged_writes(self) -> None:
self.sampling_states.apply_staged_writes()
self.penalties_state.apply_staged_writes()
self.logit_bias_state.apply_staged_writes()
self.bad_words_state.apply_staged_writes()
self.logprob_token_ids_state.apply_staged_writes()

def __call__(
self,
@@ -79,13 +85,23 @@ def __call__(
)

max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
if max_num_logprobs != NO_LOGPROBS:
max_per_req_token_ids = self.logprob_token_ids_state.max_num_token_ids(
idx_mapping_np
)
if max_num_logprobs != NO_LOGPROBS or max_per_req_token_ids > 0:
if self.logprobs_mode == "processed_logprobs":
logits = processed_logits
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
num_logprobs = max_num_logprobs if max_num_logprobs != NO_LOGPROBS else 0
logprobs_tensors = compute_topk_logprobs(
logits, max_num_logprobs, sampled, cu_num_logits
logits,
num_logprobs,
sampled,
cu_num_logits,
logprob_token_ids_state=self.logprob_token_ids_state,
expanded_idx_mapping=input_batch.expanded_idx_mapping,
max_per_req_token_ids=max_per_req_token_ids,
)
else:
logprobs_tensors = None
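A worked example of the resulting row widths under this scheme (illustrative arithmetic, not code from the diff):

# Suppose one request in the batch asks for top-5 logprobs, another passes
# 8 custom logprob_token_ids, and a third asks for neither.
max_num_logprobs = 5          # largest `logprobs` across the batch
max_per_req_token_ids = 8     # longest `logprob_token_ids` list across the batch
num_cols = max(max_num_logprobs, max_per_req_token_ids)
# Every row of the returned LogprobsTensors has 1 + num_cols columns: the
# sampled token plus either top-k indices or that request's custom ids,
# with unused columns masked to -inf.
assert 1 + num_cols == 9

A single batch-wide width keeps the kernel launch and the LogprobsTensors layout uniform; per-request narrowing happens entirely through the validity mask.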