diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2938511112a0..c1ae06612867 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2073,15 +2073,18 @@ def prepare_mlp_sync_batch_raw( num_tokens_for_logprob = num_tokens else: num_tokens = local_batch.extend_num_tokens - num_tokens_for_logprob = sum( - [ + if local_batch.return_logprob: + num_tokens_for_logprob = sum( # We should have at least 1 token for sample in every case. max(extend_len - logprob_start_len, 1) for logprob_start_len, extend_len in zip( - local_batch.extend_logprob_start_lens, local_batch.extend_lens + local_batch.extend_logprob_start_lens, + local_batch.extend_lens, ) - ] - ) + ) + else: + # When return_logprob is False, only the last token per request is needed. + num_tokens_for_logprob = local_batch.batch_size() if local_batch is None or local_batch.forward_mode.is_decode_or_idle(): can_cuda_graph = 1