Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 63 additions & 21 deletions python/sglang/bench_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,7 @@ def get_dataset(args, tokenizer, model_id=None):
system_prompt_len=args.gsp_system_prompt_len,
question_len=args.gsp_question_len,
output_len=args.gsp_output_len,
range_ratio=getattr(args, "gsp_range_ratio", 1.0),
tokenizer=tokenizer,
args=args,
)
Expand Down Expand Up @@ -1235,6 +1236,14 @@ def sample_sharegpt_requests(
return filtered_dataset


def compute_random_lens(full_len: int, range_ratio: float, num: int):
    """Sample ``num`` integer lengths uniformly from [lower, full_len].

    The lower bound is ``int(full_len * range_ratio)``, clamped to at least 1
    so every sampled length is positive. With ``range_ratio == 1.0`` every
    sample is exactly ``full_len``.

    Args:
        full_len: Inclusive upper bound for the sampled lengths.
        range_ratio: Fraction of ``full_len`` forming the lower bound.
        num: Number of lengths to draw.

    Returns:
        A numpy integer array of shape ``(num,)``.
    """
    # Clamp to 1 so a small range_ratio can never produce zero-length requests.
    lower_bound = max(int(full_len * range_ratio), 1)
    # randint's high end is exclusive, hence full_len + 1 to include full_len.
    return np.random.randint(lower_bound, full_len + 1, size=num)


def sample_random_requests(
input_len: int,
output_len: int,
Expand All @@ -1245,15 +1254,15 @@ def sample_random_requests(
random_sample: bool = True,
return_text: bool = True,
) -> List[DatasetRow]:
input_lens = np.random.randint(
max(int(input_len * range_ratio), 1),
input_len + 1,
size=num_prompts,
input_lens = compute_random_lens(
full_len=input_len,
range_ratio=range_ratio,
num=num_prompts,
)
output_lens = np.random.randint(
int(output_len * range_ratio),
output_len + 1,
size=num_prompts,
output_lens = compute_random_lens(
full_len=output_len,
range_ratio=range_ratio,
num=num_prompts,
)

if random_sample:
Expand Down Expand Up @@ -1476,11 +1485,15 @@ def sample_image_requests(
)

# Sample text lengths
input_lens = np.random.randint(
max(int(input_len * range_ratio), 1), input_len + 1, size=num_requests
input_lens = compute_random_lens(
full_len=input_len,
range_ratio=range_ratio,
num=num_requests,
)
output_lens = np.random.randint(
int(output_len * range_ratio), output_len + 1, size=num_requests
output_lens = compute_random_lens(
full_len=output_len,
range_ratio=range_ratio,
num=num_requests,
)

def _gen_random_image_data_uri(
Expand Down Expand Up @@ -1576,30 +1589,51 @@ def sample_generated_shared_prefix_requests(
system_prompt_len: int,
question_len: int,
output_len: int,
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace,
) -> List[DatasetRow]:
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
cache_path = get_gen_prefix_cache_path(args, tokenizer)

# Try to load from cache first
if cache_path.exists():
if cache_path.exists() and range_ratio == 1:
print(f"\nLoading cached generated input data from {cache_path}")
with open(cache_path, "rb") as f:
return pickle.load(f)

print("\nGenerating new input data...")
print(
f"\nGenerating new input data... "
f"({num_groups=}, {prompts_per_group}, {system_prompt_len=}, {question_len=}, {output_len=}, {range_ratio=})"
)

system_prompt_lens = compute_random_lens(
full_len=system_prompt_len,
range_ratio=range_ratio,
num=num_groups,
)
question_lens = compute_random_lens(
full_len=question_len,
range_ratio=range_ratio,
num=num_groups * prompts_per_group,
)
output_lens = compute_random_lens(
full_len=output_len,
range_ratio=range_ratio,
num=num_groups * prompts_per_group,
)
del system_prompt_len, question_len, output_len

# Generate system prompts for each group
system_prompts = []
for _ in range(num_groups):
system_prompt = gen_prompt(tokenizer, system_prompt_len)
for i in range(num_groups):
system_prompt = gen_prompt(tokenizer, system_prompt_lens[i].item())
system_prompts.append(system_prompt)

# Generate questions
questions = []
for _ in range(num_groups * prompts_per_group):
question = gen_prompt(tokenizer, question_len)
for i in range(num_groups * prompts_per_group):
question = gen_prompt(tokenizer, question_lens[i].item())
questions.append(question)

# Combine system prompts with questions
Expand All @@ -1612,19 +1646,20 @@ def sample_generated_shared_prefix_requests(
for prompt_idx in tqdm(
range(prompts_per_group), desc="Generating questions", leave=False
):
question = questions[group_idx * prompts_per_group + prompt_idx]
flat_index = group_idx * prompts_per_group + prompt_idx
question = questions[flat_index]
full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt))

input_requests.append(
DatasetRow(
prompt=full_prompt,
prompt_len=prompt_len,
output_len=output_len,
output_len=output_lens[flat_index].item(),
)
)
total_input_tokens += prompt_len
total_output_tokens += output_len
total_output_tokens += output_lens[flat_index].item()

# Shuffle questions
random.shuffle(input_requests)
Expand Down Expand Up @@ -2861,6 +2896,13 @@ def __call__(self, parser, namespace, values, option_string=None):
default=256,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
parser.add_argument(
"--gsp-range-ratio",
type=float,
# WARN: The default 1.0 is for backward compatibility, and is different from the default 0.0 for random dataset
default=1.0,
help="Range of sampled ratio of input/output length, used only for gsp dataset.",
)
mooncake_group = parser.add_argument_group("mooncake dataset arguments")
mooncake_group.add_argument(
"--mooncake-slowdown-factor",
Expand Down
Loading