6 changes: 3 additions & 3 deletions python/sglang/bench_one_batch.py
@@ -302,7 +302,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
)
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
req.logprob_start_len = len(req.origin_input_ids) - 1
req.logprob_start_len = -1
reqs.append(req)

return input_ids, reqs
@@ -318,7 +318,7 @@ def prepare_extend_inputs_for_correctness_test(
i, : bench_args.cut_len
]
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
req.logprob_start_len = len(req.origin_input_ids) - 1
req.logprob_start_len = -1
return reqs


@@ -345,7 +345,7 @@ def prepare_synthetic_inputs_for_latency_test(
)
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
req.logprob_start_len = len(req.origin_input_ids) - 1
req.logprob_start_len = -1
reqs.append(req)

return reqs
41 changes: 18 additions & 23 deletions python/sglang/srt/managers/schedule_batch.py
@@ -851,7 +851,7 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
input_len = len(self.fill_ids)
# NOTE: the matched length is at most 1 less than the input length to enable logprob computation
max_prefix_len = input_len - 1
if self.return_logprob:
if self.return_logprob and self.logprob_start_len >= 0:
max_prefix_len = min(max_prefix_len, self.logprob_start_len)
max_prefix_len = max(max_prefix_len, 0)
token_ids = self.fill_ids[:max_prefix_len]
@@ -1120,6 +1120,7 @@ def set_finish_with_abort(self, error_msg: str):
self.grammar = None
self.origin_input_ids = [0] # set it to one token to skip the long prefill
self.return_logprob = False
self.logprob_start_len = -1
self.to_finish = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
)
@@ -1490,26 +1491,16 @@ def prepare_for_extend(self):
# (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
# and prefix_indices are the cached/shared prefix tokens)
#
if req.logprob_start_len >= pre_len:
# Optimization for prefill-only requests: When we only need logprobs at
# positions beyond the input sequence (to score next-token likelihood), skip all
# input logprob computation during prefill since no generation will occur.
if self.is_prefill_only and req.logprob_start_len == len(
req.origin_input_ids
):
# Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
req.extend_logprob_start_len = req.extend_input_len
else:
# Convert absolute logprob_start_len to relative extend_logprob_start_len
#
# Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
# Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
# This means: "compute logprobs from position 3 onwards in extend batch"
req.extend_logprob_start_len = min(
req.logprob_start_len - pre_len,
req.extend_input_len,
req.seqlen - 1,
)
if req.logprob_start_len == -1:
req.extend_logprob_start_len = min(
len(req.fill_ids) - 1 - pre_len,
req.extend_input_len,
)
elif req.logprob_start_len >= pre_len:
req.extend_logprob_start_len = min(
req.logprob_start_len - pre_len,
req.extend_input_len,
)
else:
# logprob_start_len is before the current extend batch, so start from beginning
req.extend_logprob_start_len = 0
@@ -1532,9 +1523,13 @@ def prepare_for_extend(self):
len(req.prefix_indices),
len(req.fill_ids),
)
if req.logprob_start_len == -1:
logprob_start_len = len(req.origin_input_ids) - 1
else:
logprob_start_len = req.logprob_start_len
# Apply logprob_start_len
if global_start_idx < req.logprob_start_len:
global_start_idx = req.logprob_start_len
if global_start_idx < logprob_start_len:
global_start_idx = logprob_start_len

logprob_token_ids = req.origin_input_ids[
global_start_idx + 1 : global_end_idx + 1
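For readability, here is a minimal standalone sketch (not part of the diff) of how the new -1 sentinel is resolved into extend_logprob_start_len by the prepare_for_extend changes above; the helper name and its standalone form are illustrative only.

# Illustrative sketch only -- mirrors the new prepare_for_extend logic above;
# the function name and standalone form are hypothetical, not part of the PR.
def resolve_extend_logprob_start_len(
    logprob_start_len: int,  # -1 means "use the default: logprobs needed for output tokens only"
    pre_len: int,            # number of cached/shared prefix tokens
    fill_ids_len: int,       # len(origin_input_ids + output_ids)
    extend_input_len: int,   # fill_ids_len - pre_len
) -> int:
    if logprob_start_len == -1:
        # Default: start right before the last input position, relative to this extend batch.
        return min(fill_ids_len - 1 - pre_len, extend_input_len)
    elif logprob_start_len >= pre_len:
        # Absolute start position falls inside this extend batch; convert to a relative offset.
        return min(logprob_start_len - pre_len, extend_input_len)
    else:
        # Start position lies before the current extend batch, so compute from the beginning.
        return 0

For example, with no cached prefix (pre_len=0) and a 5-token prompt, the -1 default resolves to min(5 - 1 - 0, 5) = 4, i.e., input logprobs are only computed from the final input position that feeds sampling.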
23 changes: 11 additions & 12 deletions python/sglang/srt/managers/scheduler.py
@@ -1524,24 +1524,23 @@ def handle_generate_request(
self._add_request_to_queue(req)
return

# Copy more attributes
if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
# By default, only return the logprobs for output tokens
# For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
# to skip input logprob computation entirely
if recv_req.logprob_start_len == -1:
if req.is_prefill_only:
# For prefill-only requests with logprob_start_len == -1, set logprob_start_len
# beyond input sequence to skip input logprob computation entirely
req.logprob_start_len = len(req.origin_input_ids)
else:
# TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
elif recv_req.return_logprob:
# If return_logprob is True, return the logprobs for output tokens by default
req.logprob_start_len = len(req.origin_input_ids) - 1
else:
# If return_logprob is False, only the last token requires logprob computation
Collaborator:

In the previous code, we didn't propagate logprob_start_len = -1 beyond the Scheduler; the value was reset to len(req.origin_input_ids) - 1. Hence, for prefill-only requests, we followed the same approach.

Given that this change propagates logprob_start_len = -1 into schedule_batch, we could apply the same logic for prefill-only requests as well, instead of resetting to len(req.origin_input_ids) in line 1531.

But it looks like the current change doesn't break the prefill-only logprob computation, so we are good.

req.logprob_start_len = -1
else:
req.logprob_start_len = recv_req.logprob_start_len

if not req.is_prefill_only and req.logprob_start_len >= len(
req.origin_input_ids
):
if req.logprob_start_len > len(req.origin_input_ids):
error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
req.logprob_start_len = len(req.origin_input_ids) - 1
req.logprob_start_len = -1
req.set_finish_with_abort(error_msg)
self._add_request_to_queue(req)
return
@@ -1760,7 +1759,7 @@ def handle_embedding_request(
return

# Copy more attributes
req.logprob_start_len = len(req.origin_input_ids) - 1
req.logprob_start_len = -1
self._add_request_to_queue(req)

def handle_batch_embedding_request(
22 changes: 11 additions & 11 deletions python/sglang/srt/managers/scheduler_dp_attn_mixin.py
@@ -121,18 +121,18 @@ def prepare_mlp_sync_batch_raw(
num_tokens_for_logprob = num_tokens
else:
num_tokens = local_batch.extend_num_tokens
if local_batch.return_logprob:
num_tokens_for_logprob = sum(
# We should have at least 1 token for sample in every case.
max(extend_len - logprob_start_len, 1)
for logprob_start_len, extend_len in zip(
local_batch.extend_logprob_start_lens,
local_batch.extend_lens,
)
num_tokens_for_logprob = sum(
# We should have at least 1 token for sample in every case.
max(extend_len - logprob_start_len, 1)
for logprob_start_len, extend_len in zip(
local_batch.extend_logprob_start_lens,
local_batch.extend_lens,
)
else:
# When return_logprob = False, only need last token per request
num_tokens_for_logprob = local_batch.batch_size()
)
assert (
local_batch.return_logprob
or num_tokens_for_logprob == local_batch.batch_size()
)

skip_all_gather = envs.SGLANG_SCHEDULER_SKIP_ALL_GATHER.get()
can_cuda_graph = (
@@ -591,18 +591,18 @@ def _calculate_relevant_tokens_len(self, req: Req) -> int:
For regular requests, all positions from logprob_start_len onwards have logprobs.
"""
is_multi_item_scoring = self._is_multi_item_scoring(req)
relevant_tokens = req.origin_input_ids[req.logprob_start_len :]

if is_multi_item_scoring:
# Multi-item scoring: count delimiter tokens from logprob_start_len onwards
relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
return sum(
1
for token_id in relevant_tokens
if token_id == self.server_args.multi_item_scoring_delimiter
)
else:
# Regular request: all tokens from logprob_start_len onwards
return len(req.origin_input_ids) - req.logprob_start_len
return len(relevant_tokens)

def _calculate_num_input_logprobs(
self, req: Req, extend_input_len: int, extend_logprob_start_len: int
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/scheduler_pp_mixin.py
@@ -565,7 +565,7 @@ def profile_and_init_predictor(self: Scheduler):
)
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
req.logprob_start_len = len(req.origin_input_ids) - 1
req.logprob_start_len = -1

# Prepare batch
batch = ScheduleBatch.init_new(
2 changes: 1 addition & 1 deletion test/manual/test_forward_split_prefill.py
@@ -93,7 +93,7 @@ def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True)
)
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
req.logprob_start_len = len(req.origin_input_ids) - 1
req.logprob_start_len = -1
reqs.append(req)

# Create dummy tree_cache for tests (no prefix caching, just allocation)
37 changes: 37 additions & 0 deletions test/srt/test_dp_attention.py
@@ -3,8 +3,10 @@

import requests

from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
@@ -58,6 +60,41 @@ def test_mgsm_en(self):
self.assertGreater(metrics["score"], 0.8)


class TestDPRetract(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--enable-dp-attention",
"--dp",
"2",
"--max-total-tokens",
"4500",
"--max-running-requests",
"128",
"--chunked-prefill-size",
"256",
],
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_radix_attention(self):
with envs.SGLANG_TEST_RETRACT.override(True):
run_radix_attention_test(self.base_url)
self.assertIsNone(self.process.poll())


class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
@classmethod
def setUpClass(cls):