Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions scripts/playground/bench_speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import requests

from sglang.bench_serving import benchmark, set_global_args
from sglang.bench_serving import DatasetRow, benchmark, set_global_args
from sglang.srt.server_args import ServerArgs
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
Expand Down Expand Up @@ -54,7 +54,7 @@ def send_one_batch(base_url, num_prompts, batch_size):
]

# format: (prompt, input_len, output len). We set input_len as a dummy value 0.
input_requests = [(p, 0, 512) for p in padded_prompts]
input_requests: List[DatasetRow] = [DatasetRow(p, 0, 512) for p in padded_prompts]

# We need to set some dummy values in order to call `benchmark` below.
args = SimpleNamespace(
Expand All @@ -69,6 +69,8 @@ def send_one_batch(base_url, num_prompts, batch_size):
random_output_len=None,
random_range_ratio=None,
output_file=None,
warmup_requests=1,
output_details=False,
)
set_global_args(args)
tokenizer = FakeTokenizer()
Expand Down Expand Up @@ -97,7 +99,9 @@ def send_one_batch(base_url, num_prompts, batch_size):

server_info = requests.get(base_url + "/get_server_info").json()
# We use 20% percentile instead of median on purpose
step_time = np.percentile(server_info["step_time_dict"][str(batch_size)], 20)
step_time = np.percentile(
server_info["internal_states"][0]["step_time_dict"][str(batch_size)], 20
)
Comment on lines +102 to +104
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change assumes that server_info["internal_states"] is always a list. However, this is only true when dp_size > 1. When dp_size == 1 (the default), server_info["internal_states"] will be a dictionary, not a list, and server_info["internal_states"][0] will raise a TypeError.

To make this script robust for both cases, you could check the type of server_info["internal_states"]. A better long-term solution would be to make the /get_server_info endpoint return a consistent format.

Here is a suggestion to make the script work for both cases:

    step_time = np.percentile(
        (
            server_info["internal_states"][0]
            if isinstance(server_info["internal_states"], list)
            else server_info["internal_states"]["internal_states"]
        )["step_time_dict"][str(batch_size)],
        20,
    )

speed = 1 / step_time * acc_length

return (
Expand Down
Loading