diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
index 457d120d95b..294d3f688ef 100644
--- a/python/sglang/bench_offline_throughput.py
+++ b/python/sglang/bench_offline_throughput.py
@@ -60,6 +60,8 @@ class BenchArgs:
     skip_warmup: bool = False
     do_not_exit: bool = False
     prompt_suffix: str = ""
+    return_logprob: bool = False
+    logprob_start_len: int = -1
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -187,6 +189,17 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default="",
             help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
         )
+        parser.add_argument(
+            "--return-logprob",
+            action="store_true",
+            help="Enable returning log probabilities.",
+        )
+        parser.add_argument(
+            "--logprob-start-len",
+            type=int,
+            default=-1,
+            help="Start length for logprob. -1 means only return logprobs for output tokens (default). 0 means return logprobs for all tokens including input.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -201,6 +214,8 @@ def throughput_test_once(
     ignore_eos: bool,
     extra_request_body: Dict,
     profile: bool,
+    return_logprob: bool = False,
+    logprob_start_len: int = -1,
 ):
     measurement_results = {
         "backend": backend_name,
@@ -233,7 +248,12 @@ def throughput_test_once(
         backend.start_profile()
 
     st = time.perf_counter()
-    gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
+    gen_out = backend.generate(
+        prompt=prompt,
+        sampling_params=sampling_params,
+        return_logprob=return_logprob,
+        logprob_start_len=logprob_start_len,
+    )
     latency = time.perf_counter() - st
 
     if profile:
@@ -355,6 +375,8 @@ def throughput_test(
             ignore_eos=not bench_args.disable_ignore_eos,
             extra_request_body=extra_request_body,
             profile=False,
+            return_logprob=bench_args.return_logprob,
+            logprob_start_len=bench_args.logprob_start_len,
         )
         time.sleep(0.5)
 
@@ -366,6 +388,8 @@ def throughput_test(
         ignore_eos=not bench_args.disable_ignore_eos,
         extra_request_body=extra_request_body,
         profile=bench_args.profile,
+        return_logprob=bench_args.return_logprob,
+        logprob_start_len=bench_args.logprob_start_len,
     )
 
     backend.shutdown()