From 5084ee1dd095ae3959c424bb4e59d382868812d6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Dec 2025 22:38:09 +0800 Subject: [PATCH] more --- python/sglang/bench_serving.py | 84 +++++++++++++++++++++++++--------- 1 file changed, 63 insertions(+), 21 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 1ab1ffdac465..3e6a5a91eda2 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -817,6 +817,7 @@ def get_dataset(args, tokenizer, model_id=None): system_prompt_len=args.gsp_system_prompt_len, question_len=args.gsp_question_len, output_len=args.gsp_output_len, + range_ratio=getattr(args, "gsp_range_ratio", 1.0), tokenizer=tokenizer, args=args, ) @@ -1235,6 +1236,14 @@ def sample_sharegpt_requests( return filtered_dataset +def compute_random_lens(full_len: int, range_ratio: float, num: int): + return np.random.randint( + max(int(full_len * range_ratio), 1), + full_len + 1, + size=num, + ) + + def sample_random_requests( input_len: int, output_len: int, @@ -1245,15 +1254,15 @@ def sample_random_requests( random_sample: bool = True, return_text: bool = True, ) -> List[DatasetRow]: - input_lens = np.random.randint( - max(int(input_len * range_ratio), 1), - input_len + 1, - size=num_prompts, + input_lens = compute_random_lens( + full_len=input_len, + range_ratio=range_ratio, + num=num_prompts, ) - output_lens = np.random.randint( - int(output_len * range_ratio), - output_len + 1, - size=num_prompts, + output_lens = compute_random_lens( + full_len=output_len, + range_ratio=range_ratio, + num=num_prompts, ) if random_sample: @@ -1476,11 +1485,15 @@ def sample_image_requests( ) # Sample text lengths - input_lens = np.random.randint( - max(int(input_len * range_ratio), 1), input_len + 1, size=num_requests + input_lens = compute_random_lens( + full_len=input_len, + range_ratio=range_ratio, + num=num_requests, ) - output_lens = np.random.randint( - int(output_len * range_ratio), output_len + 1, 
size=num_requests + output_lens = compute_random_lens( + full_len=output_len, + range_ratio=range_ratio, + num=num_requests, ) def _gen_random_image_data_uri( @@ -1576,6 +1589,7 @@ def sample_generated_shared_prefix_requests( system_prompt_len: int, question_len: int, output_len: int, + range_ratio: float, tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace, ) -> List[DatasetRow]: @@ -1583,23 +1597,43 @@ cache_path = get_gen_prefix_cache_path(args, tokenizer) # Try to load from cache first - if cache_path.exists(): + if cache_path.exists() and range_ratio == 1: print(f"\nLoading cached generated input data from {cache_path}") with open(cache_path, "rb") as f: return pickle.load(f) - print("\nGenerating new input data...") + print( + f"\nGenerating new input data... " + f"({num_groups=}, {prompts_per_group=}, {system_prompt_len=}, {question_len=}, {output_len=}, {range_ratio=})" + ) + + system_prompt_lens = compute_random_lens( + full_len=system_prompt_len, + range_ratio=range_ratio, + num=num_groups, + ) + question_lens = compute_random_lens( + full_len=question_len, + range_ratio=range_ratio, + num=num_groups * prompts_per_group, + ) + output_lens = compute_random_lens( + full_len=output_len, + range_ratio=range_ratio, + num=num_groups * prompts_per_group, + ) + del system_prompt_len, question_len, output_len # Generate system prompts for each group system_prompts = [] - for _ in range(num_groups): - system_prompt = gen_prompt(tokenizer, system_prompt_len) + for i in range(num_groups): + system_prompt = gen_prompt(tokenizer, system_prompt_lens[i].item()) system_prompts.append(system_prompt) # Generate questions questions = [] - for _ in range(num_groups * prompts_per_group): - question = gen_prompt(tokenizer, question_len) + for i in range(num_groups * prompts_per_group): + question = gen_prompt(tokenizer, question_lens[i].item()) questions.append(question) # Combine system prompts with questions @@ -1612,7 
+1646,8 @@ def sample_generated_shared_prefix_requests( for prompt_idx in tqdm( range(prompts_per_group), desc="Generating questions", leave=False ): - question = questions[group_idx * prompts_per_group + prompt_idx] + flat_index = group_idx * prompts_per_group + prompt_idx + question = questions[flat_index] full_prompt = f"{system_prompt}\n\n{question}" prompt_len = len(tokenizer.encode(full_prompt)) @@ -1620,11 +1655,11 @@ def sample_generated_shared_prefix_requests( DatasetRow( prompt=full_prompt, prompt_len=prompt_len, - output_len=output_len, + output_len=output_lens[flat_index].item(), ) ) total_input_tokens += prompt_len - total_output_tokens += output_len + total_output_tokens += output_lens[flat_index].item() # Shuffle questions random.shuffle(input_requests) @@ -2861,6 +2896,13 @@ def __call__(self, parser, namespace, values, option_string=None): default=256, help="Target length in tokens for outputs in generated-shared-prefix dataset", ) + parser.add_argument( + "--gsp-range-ratio", + type=float, + # WARN: The default 1.0 is for backward compatibility, and is different from the default 0.0 for random dataset + default=1.0, + help="Range of sampled ratio of input/output length, used only for gsp dataset.", + ) mooncake_group = parser.add_argument_group("mooncake dataset arguments") mooncake_group.add_argument( "--mooncake-slowdown-factor",