From fff44543ddc70da40a99013f21a4d7da00990a96 Mon Sep 17 00:00:00 2001
From: Yangmin Li
Date: Sat, 25 Oct 2025 01:29:34 -0700
Subject: [PATCH 1/2] fix num_tokens_for_logprob calculation when
 return_logprob=False

---
 python/sglang/srt/managers/scheduler.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 2938511112a0..4dc176a2576a 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -2073,15 +2073,19 @@ def prepare_mlp_sync_batch_raw(
             num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
-            num_tokens_for_logprob = sum(
-                [
-                    # We should have at least 1 token for sample in every case.
-                    max(extend_len - logprob_start_len, 1)
-                    for logprob_start_len, extend_len in zip(
-                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
-                    )
-                ]
-            )
+            if local_batch.return_logprob:
+                num_tokens_for_logprob = sum(
+                    [
+                        # We should have at least 1 token for sample in every case.
+                        max(extend_len - logprob_start_len, 1)
+                        for logprob_start_len, extend_len in zip(
+                            local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                        )
+                    ]
+                )
+            else:
+                # When return_logprob = False, only need last token per request
+                num_tokens_for_logprob = local_batch.batch_size()
 
         if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
             can_cuda_graph = 1

From c2203c41d8b95cb824e4e21133436a6aedaf2313 Mon Sep 17 00:00:00 2001
From: Yangmin Li
Date: Sat, 25 Oct 2025 01:51:08 -0700
Subject: [PATCH 2/2] fix format & gemini suggestion

---
 python/sglang/srt/managers/scheduler.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 4dc176a2576a..c1ae06612867 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -2075,13 +2075,12 @@ def prepare_mlp_sync_batch_raw(
             num_tokens = local_batch.extend_num_tokens
             if local_batch.return_logprob:
                 num_tokens_for_logprob = sum(
-                    [
-                        # We should have at least 1 token for sample in every case.
-                        max(extend_len - logprob_start_len, 1)
-                        for logprob_start_len, extend_len in zip(
-                            local_batch.extend_logprob_start_lens, local_batch.extend_lens
-                        )
-                    ]
+                    # We should have at least 1 token for sample in every case.
+                    max(extend_len - logprob_start_len, 1)
+                    for logprob_start_len, extend_len in zip(
+                        local_batch.extend_logprob_start_lens,
+                        local_batch.extend_lens,
+                    )
                 )
             else:
                 # When return_logprob = False, only need last token per request
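
Note (not part of either patch): a minimal standalone sketch of the
computation the two patches converge on, for illustration only. The
helper name compute_num_tokens_for_logprob and the SimpleNamespace
stand-in for the scheduler's local_batch are hypothetical, not sglang
APIs; the branch logic mirrors the final hunk above.

    from types import SimpleNamespace

    def compute_num_tokens_for_logprob(local_batch):
        # Mirrors the patched branch of prepare_mlp_sync_batch_raw.
        if local_batch.return_logprob:
            # One logits row per token past logprob_start_len, but never
            # fewer than 1 per request (the sampled token).
            return sum(
                max(extend_len - logprob_start_len, 1)
                for logprob_start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens,
                    local_batch.extend_lens,
                )
            )
        # return_logprob is False: only the last token of each request
        # needs logits for sampling. The patch uses
        # local_batch.batch_size(); this stand-in derives the same count
        # from the per-request extend lengths.
        return len(local_batch.extend_lens)

    # Two requests with extend lengths 5 and 3, logprobs from index 2.
    batch = SimpleNamespace(
        return_logprob=True,
        extend_logprob_start_lens=[2, 2],
        extend_lens=[5, 3],
    )
    assert compute_num_tokens_for_logprob(batch) == (5 - 2) + (3 - 2)  # 4

    batch.return_logprob = False
    assert compute_num_tokens_for_logprob(batch) == 2  # one per request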