diff --git a/scripts/playground/bench_speculative.py b/scripts/playground/bench_speculative.py
index a8676556c27..f16ff4460a2 100644
--- a/scripts/playground/bench_speculative.py
+++ b/scripts/playground/bench_speculative.py
@@ -17,7 +17,7 @@
 import numpy as np
 import requests
 
-from sglang.bench_serving import benchmark, set_global_args
+from sglang.bench_serving import DatasetRow, benchmark, set_global_args
 from sglang.srt.server_args import ServerArgs
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -54,7 +54,7 @@ def send_one_batch(base_url, num_prompts, batch_size):
     ]
 
     # format: (prompt, input_len, output len). We set input_len as a dummy value 0.
-    input_requests = [(p, 0, 512) for p in padded_prompts]
+    input_requests: List[DatasetRow] = [DatasetRow(p, 0, 512) for p in padded_prompts]
 
     # We need to set some dummy values in order to call `benchmark` below.
     args = SimpleNamespace(
@@ -69,6 +69,8 @@ def send_one_batch(base_url, num_prompts, batch_size):
         random_output_len=None,
         random_range_ratio=None,
         output_file=None,
+        warmup_requests=1,
+        output_details=False,
     )
     set_global_args(args)
     tokenizer = FakeTokenizer()
@@ -97,7 +99,9 @@ def send_one_batch(base_url, num_prompts, batch_size):
     server_info = requests.get(base_url + "/get_server_info").json()
 
     # We use 20% percentile instead of median on purpose
-    step_time = np.percentile(server_info["step_time_dict"][str(batch_size)], 20)
+    step_time = np.percentile(
+        server_info["internal_states"][0]["step_time_dict"][str(batch_size)], 20
+    )
     speed = 1 / step_time * acc_length
 
     return (
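For reference, a minimal sketch (not part of the patch) of the two behavioral changes above: requests are now built as `DatasetRow` objects instead of plain tuples, and `step_time_dict` is read from `internal_states[0]` of the `/get_server_info` response. The `DatasetRow` constructor arguments and the response shape are assumed from the hunks; the helper names here are illustrative only.

```python
# Sketch only: mirrors the patched logic; names and shapes are assumed from the diff.
from typing import List

import numpy as np
import requests

from sglang.bench_serving import DatasetRow


def build_requests(padded_prompts: List[str]) -> List[DatasetRow]:
    # Each request carries (prompt, dummy input_len=0, output_len=512), as in the patch.
    return [DatasetRow(p, 0, 512) for p in padded_prompts]


def get_step_time(base_url: str, batch_size: int) -> float:
    # step_time_dict now lives under internal_states[0] in the server info response.
    server_info = requests.get(base_url + "/get_server_info").json()
    samples = server_info["internal_states"][0]["step_time_dict"][str(batch_size)]
    # 20th percentile instead of the median, matching the benchmark script.
    return float(np.percentile(samples, 20))
```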