Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions python/sglang/bench_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,15 @@ def get_dataset(args, tokenizer):
prompt_suffix=args.prompt_suffix,
apply_chat_template=args.apply_chat_template,
)
- elif args.dataset_name == "random":
+ elif args.dataset_name.startswith("random"):
input_requests = sample_random_requests(
input_len=args.random_input_len,
output_len=args.random_output_len,
num_prompts=args.num_prompts,
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
dataset_path=args.dataset_path,
+ random_sample=args.dataset_name == "random",
)
elif args.dataset_name == "generated-shared-prefix":
input_requests = sample_generated_shared_prefix_requests(
Expand Down Expand Up @@ -687,6 +688,7 @@ def sample_random_requests(
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
dataset_path: str,
+ random_sample: bool = True,
) -> List[Tuple[str, int, int]]:

input_lens = np.random.randint(
Expand All @@ -700,11 +702,15 @@ def sample_random_requests(
size=num_prompts,
)

- if True:
+ if random_sample:
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
+ print(
+ "If you do not want to randomly sample from a dataset,"
+ " please use --dataset-name random-ids."
+ )
dataset_path = download_and_cache_file(SHAREGPT_URL)

# Load the dataset.
Expand Down Expand Up @@ -1223,7 +1229,7 @@ async def limited_request_func(request_func_input, pbar):
output_file_name = args.output_file
else:
now = datetime.now().strftime("%m%d")
- if args.dataset_name == "random":
+ if args.dataset_name.startswith("random"):
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
else:
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
Expand Down Expand Up @@ -1442,7 +1448,7 @@ def __call__(self, parser, namespace, values, option_string=None):
"--dataset-name",
type=str,
default="sharegpt",
- choices=["sharegpt", "random", "generated-shared-prefix"],
+ choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument(
Expand Down
Loading