Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ class LogitsProcessorOutput:
next_token_top_logprobs_val: Optional[List] = None
next_token_top_logprobs_idx: Optional[List] = None
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
next_token_token_ids_logprobs_val: Optional[List] = None
# Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
next_token_token_ids_logprobs_val: Optional[
List[Union[List[float], torch.Tensor]]
] = None
next_token_token_ids_logprobs_idx: Optional[List] = None

## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
Expand Down
171 changes: 157 additions & 14 deletions python/sglang/srt/layers/sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import List
from typing import List, Tuple

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -39,6 +39,25 @@ def __init__(self):
if is_dp_attention_enabled():
self.tp_sync_group = get_attention_tp_group().device_group

def _preprocess_logits(
    self, logits: torch.Tensor, sampling_info: "SamplingBatchInfo"
) -> torch.Tensor:
    """Apply custom logit processors and handle NaN detection.

    Args:
        logits: Next-token logits, shape [batch_size, vocab_size]. May be
            modified in place by custom logit processors.
        sampling_info: Batch sampling metadata; custom logit processors
            registered here are applied first.

    Returns:
        The (possibly sanitized) logits tensor. When NaN detection is
        enabled, NaN entries are replaced with -1e5.

    Raises:
        ValueError: If NaNs are found and ``crash_on_warnings()`` is set.
    """
    # Apply the custom logit processors if registered in the sampling info
    if sampling_info.has_custom_logit_processor:
        apply_custom_logit_processor(logits, sampling_info)

    # Detect and handle NaN values in logits. Compute the NaN mask once
    # and reuse it for both the detection check and the replacement
    # (the original computed torch.isnan twice).
    if self.use_nan_detection:
        nan_mask = torch.isnan(logits)
        if nan_mask.any():
            logger.warning("Detected errors during sampling! NaN in the logits.")
            logits = torch.where(
                nan_mask, torch.full_like(logits, -1e5), logits
            )
            if crash_on_warnings():
                raise ValueError("Detected errors during sampling! NaN in the logits.")

    return logits

def forward(
self,
logits_output: LogitsProcessorOutput,
Expand All @@ -61,17 +80,8 @@ def forward(
"""
logits = logits_output.next_token_logits

# Apply the custom logit processors if registered in the sampling info.
if sampling_info.has_custom_logit_processor:
apply_custom_logit_processor(logits, sampling_info)

if self.use_nan_detection and torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits
)
if crash_on_warnings():
raise ValueError("Detected errors during sampling! NaN in the logits.")
# Preprocess logits (custom processors and NaN handling)
logits = self._preprocess_logits(logits, sampling_info)

if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
Expand Down Expand Up @@ -164,6 +174,54 @@ def forward(

return batch_next_token_ids

def compute_logprobs_only(
    self,
    logits_output: LogitsProcessorOutput,
    sampling_info: SamplingBatchInfo,
    return_logprob: bool,
    top_logprobs_nums: List[int],
    token_ids_logprobs: List[List[int]],
) -> None:
    """
    Compute logprobs for requested token IDs without performing sampling.

    Optimized for prefill-only scoring requests that need token probabilities
    but don't require next token generation. Results are written back onto
    ``logits_output`` in place.
    """
    raw_logits = logits_output.next_token_logits
    if raw_logits is None:
        logger.warning("No logits available for logprob computation")
        return

    # Figure out which kinds of logprob output were actually requested
    # before doing any GPU work.
    wants_top = any(k > 0 for k in top_logprobs_nums)
    wants_token_ids = any(
        ids is not None and len(ids) > 0 for ids in token_ids_logprobs
    )
    if not wants_top and not wants_token_ids:
        return

    # Run custom processors / NaN sanitization, then convert to logprobs.
    processed = self._preprocess_logits(raw_logits, sampling_info)
    log_probs = torch.nn.functional.log_softmax(processed, dim=-1)

    # Top-k logprobs, if any request asked for them.
    if wants_top:
        top_vals, top_idx = get_top_logprobs(log_probs, top_logprobs_nums)
        logits_output.next_token_top_logprobs_val = top_vals
        logits_output.next_token_top_logprobs_idx = top_idx

    # Explicit token-id logprobs, extracted with one batched gather.
    if wants_token_ids:
        ids_vals, ids_idx = get_token_ids_logprobs_batch_optimized(
            log_probs, token_ids_logprobs
        )
        logits_output.next_token_token_ids_logprobs_val = ids_vals
        logits_output.next_token_token_ids_logprobs_idx = ids_idx


def top_k_top_p_min_p_sampling_from_probs_torch(
probs: torch.Tensor,
Expand Down Expand Up @@ -233,10 +291,95 @@ def get_top_logprobs(
)


def get_token_ids_logprobs(
def get_token_ids_logprobs_batch_optimized(
    logprobs: torch.Tensor,
    token_ids_logprobs: List[List[int]],
) -> Tuple[List, List]:
    """
    Vectorized batch processing for token ID logprobs extraction.

    Uses a single GPU kernel call for the entire batch instead of multiple
    separate calls, significantly improving performance for large batches.

    Args:
        logprobs: Log probabilities tensor [batch_size, vocab_size]
        token_ids_logprobs: Per-request lists of token IDs to extract
            logprobs for; an entry may be None or empty.

    Returns:
        Tuple of (values, indices): per-request logprob tensors and the
        corresponding token-id lists ([] for empty/None requests).

    Example:
        # Input: batch_size=3, vocab_size=5
        logprobs = torch.tensor([
            [-1.2, -2.1, -0.8, -3.0, -1.5],  # batch 0
            [-0.5, -1.8, -2.2, -1.1, -2.7],  # batch 1
            [-2.0, -0.9, -1.4, -2.8, -1.6],  # batch 2
        ])
        token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]

        # Output:
        # values = [tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6])]
        # indices = [[1, 3], [2], [0, 2, 4]]
    """
    device = logprobs.device

    # Step 1: Compute per-request lengths ONCE as a plain Python list
    # (treating None as empty). The original built a device tensor and then
    # called .sum().item() and .tolist() on it, which costs two device->host
    # synchronizations per call on GPU in this hot scoring path.
    # Example: [[1, 3], [2], [0, 2, 4]] -> lengths = [2, 1, 3]
    lengths = [len(token_ids or []) for token_ids in token_ids_logprobs]
    total_tokens = sum(lengths)  # 2 + 1 + 3 = 6

    # Handle edge case where no tokens are requested
    if total_tokens == 0:
        return (
            [logprobs.new_empty(0) for _ in token_ids_logprobs],
            [[] for _ in token_ids_logprobs],
        )

    # Step 2: Build flattened indices using torch operations
    # Example: row_indices = [0, 0, 1, 2, 2, 2] (batch indices repeated by their lengths)
    row_indices = torch.repeat_interleave(
        torch.arange(len(token_ids_logprobs), device=device),
        torch.tensor(lengths, device=device),
    )
    # Example: col_indices = [1, 3, 2, 0, 2, 4] (flattened token IDs from all requests)
    col_indices = torch.tensor(
        [
            token_id
            for token_ids in token_ids_logprobs
            for token_id in (token_ids or [])
        ],
        device=device,
        dtype=torch.long,
    )

    # Step 3: Single vectorized gather operation
    # Example: logprobs[row_indices, col_indices] -> [-2.1, -3.0, -2.2, -2.0, -1.4, -1.6]
    gathered_logprobs = logprobs[row_indices, col_indices]

    # Step 4: Split results back per request; sizes come from the host-side
    # lengths list, so no extra tensor round-trip is needed.
    split_logprobs = torch.split_with_sizes(gathered_logprobs, lengths, dim=0)

    # Step 5: Format output to match the expected return structure, using
    # empty tensors / empty lists for requests that asked for nothing.
    output_token_ids_logprobs_val = []
    output_token_ids_logprobs_idx = []
    for i, token_ids in enumerate(token_ids_logprobs):
        if token_ids is not None and len(token_ids) > 0:
            output_token_ids_logprobs_val.append(split_logprobs[i])
            output_token_ids_logprobs_idx.append(token_ids)
        else:
            output_token_ids_logprobs_val.append(logprobs.new_empty(0))
            output_token_ids_logprobs_idx.append([])

    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx


def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
output_token_ids_logprobs_val = []
output_token_ids_logprobs_idx = []
for i, token_ids in enumerate(token_ids_logprobs):
Expand Down
51 changes: 42 additions & 9 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,10 @@ def __init__(
# shape: (bs, k)
self.output_top_logprobs_val = []
self.output_top_logprobs_idx = []
self.output_token_ids_logprobs_val = []
# Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
self.output_token_ids_logprobs_val: List[
Union[List[float], torch.Tensor]
] = []
self.output_token_ids_logprobs_idx = []
else:
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
Expand Down Expand Up @@ -619,6 +622,11 @@ def __init__(
def seqlen(self):
return len(self.origin_input_ids) + len(self.output_ids)

@property
def is_prefill_only(self) -> bool:
    """Check if this request is prefill-only (no token generation needed).

    True when the caller requested zero new tokens (max_new_tokens == 0),
    i.e. a scoring/logprob-only request. Used when constructing batches to
    mark a whole batch as prefill-only when every request satisfies this.
    """
    return self.sampling_params.max_new_tokens == 0

def extend_image_inputs(self, image_inputs):
if self.multimodal_inputs is None:
self.multimodal_inputs = image_inputs
Expand Down Expand Up @@ -950,9 +958,7 @@ def init_new(
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
return_hidden_states=any(req.return_hidden_states for req in reqs),
is_prefill_only=all(
req.sampling_params.max_new_tokens == 0 for req in reqs
),
is_prefill_only=all(req.is_prefill_only for req in reqs),
chunked_req=chunked_req,
)

Expand Down Expand Up @@ -1210,13 +1216,36 @@ def prepare_for_extend(self):
req.is_retracted = False

# Compute the relative logprob_start_len in an extend batch
#
# Key variables:
# - logprob_start_len: Absolute position in full sequence where logprob computation begins
# - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
# - extend_input_len: Number of tokens that need to be processed in this extend batch
# (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
# and prefix_indices are the cached/shared prefix tokens)
#
if req.logprob_start_len >= pre_len:
req.extend_logprob_start_len = min(
req.logprob_start_len - pre_len,
req.extend_input_len,
req.seqlen - 1,
)
# Optimization for prefill-only requests: When we only need logprobs at
# positions beyond the input sequence (to score next-token likelihood), skip all
# input logprob computation during prefill since no generation will occur.
if self.is_prefill_only and req.logprob_start_len == len(
req.origin_input_ids
):
# Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
req.extend_logprob_start_len = req.extend_input_len
else:
# Convert absolute logprob_start_len to relative extend_logprob_start_len
#
# Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
# Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
# This means: "compute logprobs from position 3 onwards in extend batch"
req.extend_logprob_start_len = min(
req.logprob_start_len - pre_len,
req.extend_input_len,
req.seqlen - 1,
)
else:
# logprob_start_len is before the current extend batch, so start from beginning
req.extend_logprob_start_len = 0

if self.return_logprob:
Expand Down Expand Up @@ -1763,6 +1792,7 @@ def get_model_worker_batch(
),
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
launch_done=self.launch_done,
is_prefill_only=self.is_prefill_only,
)

def copy(self):
Expand Down Expand Up @@ -1905,6 +1935,9 @@ class ModelWorkerBatch:
# Overlap event
launch_done: Optional[threading.Event] = None

# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False


@triton.jit
def write_req_to_token_pool_triton(
Expand Down
12 changes: 10 additions & 2 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,11 +1261,19 @@ def handle_generate_request(
# Copy more attributes
if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(req.origin_input_ids) - 1
# For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
# to skip input logprob computation entirely
if req.is_prefill_only:
req.logprob_start_len = len(req.origin_input_ids)
else:
# TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
req.logprob_start_len = len(req.origin_input_ids) - 1
else:
req.logprob_start_len = recv_req.logprob_start_len

if req.logprob_start_len >= len(req.origin_input_ids):
if not req.is_prefill_only and req.logprob_start_len >= len(
req.origin_input_ids
):
error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
req.logprob_start_len = len(req.origin_input_ids) - 1
req.set_finish_with_abort(error_msg)
Expand Down
Loading
Loading