diff --git a/python/sglang/test/send_one.py b/python/sglang/test/send_one.py index 1a9ec1dacfce..3c0e1f215adb 100644 --- a/python/sglang/test/send_one.py +++ b/python/sglang/test/send_one.py @@ -22,6 +22,7 @@ class BenchArgs: host: str = "localhost" port: int = 30000 batch_size: int = 1 + different_prompts: bool = False temperature: float = 0.0 max_new_tokens: int = 512 frequency_penalty: float = 0.0 @@ -44,6 +45,11 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument("--host", type=str, default=BenchArgs.host) parser.add_argument("--port", type=int, default=BenchArgs.port) parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size) + parser.add_argument( + "--different-prompts", + action="store_true", + default=BenchArgs.different_prompts, + ) parser.add_argument("--temperature", type=float, default=BenchArgs.temperature) parser.add_argument( "--profile-name-prefix", type=str, default=BenchArgs.profile_name_prefix @@ -75,7 +81,7 @@ def from_cli_args(cls, args: argparse.Namespace): return cls(**{attr: getattr(args, attr) for attr in attrs}) -def send_one_prompt(args): +def send_one_prompt(args: BenchArgs): base_url = f"http://{args.host}:{args.port}" if args.image: @@ -110,7 +116,10 @@ def send_one_prompt(args): json_schema = None if args.batch_size > 1: - prompt = [prompt] * args.batch_size + if not args.different_prompts: + prompt = [prompt] * args.batch_size + else: + prompt = [f"Test case {i+1}: " + prompt for i in range(args.batch_size)] json_data = { "text": prompt, @@ -193,6 +202,6 @@ def send_one_prompt(args): if __name__ == "__main__": parser = argparse.ArgumentParser() BenchArgs.add_cli_args(parser) - args = parser.parse_args() + args = BenchArgs.from_cli_args(parser.parse_args()) send_one_prompt(args)