Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions python/sglang/lang/backend/runtime_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,17 @@ def select(

# Remove extra token if no token healing occurred
for i in range(len(input_token_logprobs)):
# Skip if no logprobs available (can happen on some backends)
if not input_token_logprobs[i] or not input_token_logprobs[i][0]:
continue
healed_token_str = input_token_logprobs[i][0][-1]
if s.text_.endswith(healed_token_str):
healed_token_logprob = input_token_logprobs[i][0][0]
normalized_prompt_logprobs[i] = (
normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
- healed_token_logprob
) / (len(input_token_logprobs[i]) - 1)
num_tokens = len(input_token_logprobs[i])
if num_tokens > 1:
normalized_prompt_logprobs[i] = (
normalized_prompt_logprobs[i] * num_tokens - healed_token_logprob
) / (num_tokens - 1)
Comment on lines +291 to +294
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While this change correctly prevents a ZeroDivisionError when num_tokens is 1, it leaves normalized_prompt_logprobs[i] with a stale value. After the healed token is removed on line 295, input_token_logprobs[i] becomes empty, but the corresponding normalized_prompt_logprobs[i] is not updated, leading to an inconsistent state. To align with the logic in compute_normalized_prompt_logprobs which returns -inf for empty logprobs, it would be more correct to set normalized_prompt_logprobs[i] to -inf when num_tokens is 1.

Suggested change
if num_tokens > 1:
normalized_prompt_logprobs[i] = (
normalized_prompt_logprobs[i] * num_tokens - healed_token_logprob
) / (num_tokens - 1)
if num_tokens > 1:
normalized_prompt_logprobs[i] = (
normalized_prompt_logprobs[i] * num_tokens - healed_token_logprob
) / (num_tokens - 1)
else:
normalized_prompt_logprobs[i] = float("-inf")

input_token_logprobs[i] = input_token_logprobs[i][1:]

# Compute unconditional logprobs if required
Expand Down Expand Up @@ -349,6 +353,9 @@ def _assert_success(self, res):

def compute_normalized_prompt_logprobs(input_logprobs):
values = [x[0] for x in input_logprobs if x[0]]
if not values:
# Return negative infinity if no valid logprobs - this choice should not be selected
return float("-inf")
return sum(values) / len(values)


Expand Down
Loading