diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py
index 1573ca68da77..9734d7cbaf12 100644
--- a/python/sglang/lang/backend/runtime_endpoint.py
+++ b/python/sglang/lang/backend/runtime_endpoint.py
@@ -281,13 +281,17 @@ def select(
 
         # Remove extra token if no token healing occurred
         for i in range(len(input_token_logprobs)):
+            # Skip if no logprobs available (can happen on some backends)
+            if not input_token_logprobs[i] or not input_token_logprobs[i][0]:
+                continue
             healed_token_str = input_token_logprobs[i][0][-1]
             if s.text_.endswith(healed_token_str):
                 healed_token_logprob = input_token_logprobs[i][0][0]
-                normalized_prompt_logprobs[i] = (
-                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
-                    - healed_token_logprob
-                ) / (len(input_token_logprobs[i]) - 1)
+                num_tokens = len(input_token_logprobs[i])
+                if num_tokens > 1:
+                    normalized_prompt_logprobs[i] = (
+                        normalized_prompt_logprobs[i] * num_tokens - healed_token_logprob
+                    ) / (num_tokens - 1)
                 input_token_logprobs[i] = input_token_logprobs[i][1:]
 
         # Compute unconditional logprobs if required
@@ -349,6 +353,14 @@ def _assert_success(self, res):
 
 def compute_normalized_prompt_logprobs(input_logprobs):
-    values = [x[0] for x in input_logprobs if x[0]]
+    """Return the mean of the first element of each entry in input_logprobs.
+
+    Entries whose first element is None are ignored (a value of 0.0 is kept);
+    if none remain, the result is float("-inf") so the choice is never selected.
+    """
+    values = [x[0] for x in input_logprobs if x[0] is not None]
+    if not values:
+        # Return negative infinity if no valid logprobs - this choice should not be selected
+        return float("-inf")
     return sum(values) / len(values)
 
 