diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py
index e64063047663..066b8fe83438 100644
--- a/vllm/benchmarks/lib/endpoint_request_func.py
+++ b/vllm/benchmarks/lib/endpoint_request_func.py
@@ -89,6 +89,7 @@ class RequestFuncOutput:
     tpot: float = 0.0  # avg next-token latencies
     prompt_len: int = 0
     error: str = ""
+    start_time: float = 0.0
 
 
 async def async_request_openai_completions(
@@ -140,6 +141,7 @@ async def async_request_openai_completions(
 
     generated_text = ""
     st = time.perf_counter()
+    output.start_time = st
     most_recent_timestamp = st
     try:
         async with session.post(url=api_url, json=payload,
@@ -272,6 +274,7 @@ async def async_request_openai_chat_completions(
     generated_text = ""
     ttft = 0.0
     st = time.perf_counter()
+    output.start_time = st
     most_recent_timestamp = st
     try:
         async with session.post(url=api_url, json=payload,
@@ -396,6 +399,7 @@ def to_bytes(y, sr):
     generated_text = ""
     ttft = 0.0
     st = time.perf_counter()
+    output.start_time = st
     most_recent_timestamp = st
     try:
         async with session.post(url=api_url,
@@ -475,6 +479,7 @@ async def async_request_openai_embeddings(
 
     output = RequestFuncOutput()
     st = time.perf_counter()
+    output.start_time = st
     try:
         async with session.post(
             url=api_url,
diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py
index 1aeef0fd5bd8..d8784340eba1 100644
--- a/vllm/benchmarks/serve.py
+++ b/vllm/benchmarks/serve.py
@@ -18,9 +18,11 @@
 import argparse
 import asyncio
 import gc
+import importlib.util
 import json
 import os
 import random
+import shutil
 import time
 import warnings
 from collections.abc import AsyncGenerator, Iterable
@@ -46,6 +48,9 @@
 
 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
 
+TERM_PLOTLIB_AVAILABLE = ((importlib.util.find_spec("termplotlib") is not None)
+                          and (shutil.which("gnuplot") is not None))
+
 
 class TaskType(Enum):
     GENERATION = "generation"
@@ -80,18 +85,23 @@ class BenchmarkMetrics:
     median_e2el_ms: float
     std_e2el_ms: float
     percentiles_e2el_ms: list[tuple[float, float]]
+    # Max output tokens per second and concurrent requests at that peak
+    max_output_tokens_per_s: float
+    max_concurrent_requests: int
+
 
 @dataclass
 class EmbedBenchmarkMetrics:
     completed: int
     total_input: int
     request_throughput: float
-    total_token_throughput :float
+    total_token_throughput: float
     mean_e2el_ms: float
     std_e2el_ms: float
     median_e2el_ms: float
     percentiles_e2el_ms: float
+
 
 def _get_current_request_rate(
     ramp_up_strategy: Optional[Literal["linear", "exponential"]],
     ramp_up_start_rps: Optional[int],
@@ -150,8 +160,8 @@ async def get_request(
     assert burstiness > 0, (
         f"A positive burstiness factor is expected, but given {burstiness}.")
     # Convert to list to get length for ramp-up calculations
-    if isinstance(input_requests, Iterable) and not isinstance(
-            input_requests, list):
+    if isinstance(input_requests,
+                  Iterable) and not isinstance(input_requests, list):
         input_requests = list(input_requests)
     total_requests = len(input_requests)
 
@@ -161,12 +171,9 @@
     request_rates = []
     delay_ts = []
     for request_index, request in enumerate(input_requests):
-        current_request_rate = _get_current_request_rate(ramp_up_strategy,
-                                                         ramp_up_start_rps,
-                                                         ramp_up_end_rps,
-                                                         request_index,
-                                                         total_requests,
-                                                         request_rate)
+        current_request_rate = _get_current_request_rate(
+            ramp_up_strategy, ramp_up_start_rps, ramp_up_end_rps,
+            request_index, total_requests, request_rate)
         request_rates.append(current_request_rate)
         if current_request_rate == float("inf"):
             delay_ts.append(0)
@@ -206,10 +213,8 @@ async def get_request(
 
 
 def calculate_metrics_for_embeddings(
-    outputs: list[RequestFuncOutput],
-    dur_s: float,
-    selected_percentiles: list[float]
-) -> EmbedBenchmarkMetrics:
+        outputs: list[RequestFuncOutput], dur_s: float,
+        selected_percentiles: list[float]) -> EmbedBenchmarkMetrics:
     """Calculate the metrics for the embedding requests.
 
     Args:
@@ -242,10 +247,8 @@ def calculate_metrics_for_embeddings(
         mean_e2el_ms=np.mean(e2els or 0) * 1000,
         std_e2el_ms=np.std(e2els or 0) * 1000,
         median_e2el_ms=np.median(e2els or 0) * 1000,
-        percentiles_e2el_ms=[
-            (p, np.percentile(e2els or 0, p) * 1000)
-            for p in selected_percentiles
-        ],
+        percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
+                             for p in selected_percentiles],
     )
     return metrics
 
@@ -336,6 +339,67 @@ def calculate_metrics(
             "All requests failed. This is likely due to a misconfiguration "
             "on the benchmark arguments.",
             stacklevel=2)
+
+    # Calculate max output tokens per second metric
+    max_output_tokens_per_s = 0.0
+    max_concurrent_requests = 0
+
+    # Find the time range across all successful requests
+    successful_outputs = [output for output in outputs if output.success]
+    if successful_outputs:
+        min_start_time = min(output.start_time
+                             for output in successful_outputs)
+        max_end_time = max(output.start_time + output.latency
+                           for output in successful_outputs)
+
+        # Create second buckets (ceiling to ensure we capture all time)
+        duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1
+        tokens_per_second = np.zeros(duration_seconds)
+        concurrent_requests_per_second = np.zeros(duration_seconds)
+
+        for i, output in enumerate(successful_outputs):
+            # Calculate token generation timestamp using
+            # start_time, ttft, and itl
+            token_times = [output.start_time + output.ttft]
+            current_time = token_times[0]
+            for itl_value in output.itl:
+                current_time += itl_value
+                token_times.append(current_time)
+
+            # Add tokens to second buckets
+            for token_time in token_times:
+                second_bucket = int(token_time - min_start_time)
+                if 0 <= second_bucket < duration_seconds:
+                    tokens_per_second[second_bucket] += 1
+
+            # Track concurrent requests for each second this request was active
+            request_start_second = int(output.start_time - min_start_time)
+            request_end_second = int((output.start_time + output.latency) -
+                                     min_start_time)
+
+            for second in range(request_start_second, request_end_second + 1):
+                concurrent_requests_per_second[second] += 1
+
+        # Find the maximum tokens per second and corresponding
+        # concurrent requests
+        if len(tokens_per_second) > 0:
+            max_output_tokens_per_s = float(np.max(tokens_per_second))
+            max_concurrent_requests = int(
+                np.max(concurrent_requests_per_second))
+
+        if TERM_PLOTLIB_AVAILABLE:
+            import termplotlib as tpl
+            fig = tpl.figure()
+            fig.plot(np.arange(len(tokens_per_second)),
+                     tokens_per_second,
+                     title="Output tokens per second")
+            fig.plot(np.arange(len(concurrent_requests_per_second)),
+                     concurrent_requests_per_second,
+                     title="Concurrent requests per second")
+            fig.show()
+        else:
+            print("tip: install termplotlib and gnuplot to plot the metrics")
+
 metrics = BenchmarkMetrics(
         completed=completed,
         total_input=total_input,
@@ -365,6 +429,8 @@ def calculate_metrics(
         median_e2el_ms=np.median(e2els or 0) * 1000,
         percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
                              for p in selected_percentiles],
+        max_output_tokens_per_s=max_output_tokens_per_s,
+        max_concurrent_requests=max_concurrent_requests,
     )
     return metrics, actual_output_lens
 
@@ -396,11 +462,8 @@ async def benchmark(
     ramp_up_end_rps: Optional[int] = None,
     ready_check_timeout_sec: int = 600,
 ):
-    task_type = (
-        TaskType.EMBEDDING
-        if api_url.endswith("/v1/embeddings")
-        else TaskType.GENERATION
-    )
+    task_type = (TaskType.EMBEDDING if api_url.endswith("/v1/embeddings") else
+                 TaskType.GENERATION)
     if endpoint_type in ASYNC_REQUEST_FUNCS:
         if task_type == TaskType.EMBEDDING:
             request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
@@ -435,14 +498,10 @@ async def benchmark(
         input_requests[0].multi_modal_data,
     )
 
-    assert (
-        test_mm_content is None
-        or isinstance(test_mm_content, dict)
-        or (
-            isinstance(test_mm_content, list)
-            and all(isinstance(item, dict) for item in test_mm_content)
-        )
-    ), "multi_modal_data must be a dict or list[dict]"
+    assert (test_mm_content is None or isinstance(test_mm_content, dict)
+            or (isinstance(test_mm_content, list)
+                and all(isinstance(item, dict) for item in test_mm_content))
+            ), "multi_modal_data must be a dict or list[dict]"
     test_input = RequestFuncInput(
         model=model_id,
         model_name=model_name,
@@ -488,13 +547,13 @@ async def benchmark(
             ignore_eos=ignore_eos,
            extra_headers=extra_headers,
             extra_body=extra_body)
-        profile_output = await request_func(
-            request_func_input=profile_input, session=session)
+        profile_output = await request_func(request_func_input=profile_input,
+                                            session=session)
         if profile_output.success:
             print("Profiler started")
 
-    distribution = ("Poisson process" if burstiness == 1.0
-                    else "Gamma distribution")
+    distribution = ("Poisson process"
+                    if burstiness == 1.0 else "Gamma distribution")
 
     if ramp_up_strategy is not None:
         print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
@@ -562,18 +621,20 @@ async def limited_request_func(request_func_input, session, pbar):
             req_lora_module = next(lora_modules)
             req_model_id, req_model_name = req_lora_module, req_lora_module
 
-        request_func_input = RequestFuncInput(model=req_model_id,
-                                              model_name=req_model_name,
-                                              prompt=prompt,
-                                              api_url=api_url,
-                                              prompt_len=prompt_len,
-                                              output_len=output_len,
-                                              logprobs=logprobs,
-                                              multi_modal_content=mm_content,
-                                              ignore_eos=ignore_eos,
-                                              extra_headers=extra_headers,
-                                              extra_body=extra_body,
-                                              request_id=request_id,)
+        request_func_input = RequestFuncInput(
+            model=req_model_id,
+            model_name=req_model_name,
+            prompt=prompt,
+            api_url=api_url,
+            prompt_len=prompt_len,
+            output_len=output_len,
+            logprobs=logprobs,
+            multi_modal_content=mm_content,
+            ignore_eos=ignore_eos,
+            extra_headers=extra_headers,
+            extra_body=extra_body,
+            request_id=request_id,
+        )
         tasks.append(
             asyncio.create_task(
                 limited_request_func(request_func_input=request_func_input,
@@ -615,19 +676,21 @@ async def limited_request_func(request_func_input, session, pbar):
                                     benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
     if isinstance(metrics, BenchmarkMetrics):
-        print("{:<40} {:<10}".format(
-            "Total generated tokens:", metrics.total_output))
+        print("{:<40} {:<10}".format("Total generated tokens:",
+                                     metrics.total_output))
     print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                     metrics.request_throughput))
     if goodput_config_dict:
         print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
                                         metrics.request_goodput))
     if isinstance(metrics, BenchmarkMetrics):
-        print(
-            "{:<40} {:<10.2f}".format(
-                "Output token throughput (tok/s):", metrics.output_throughput
-            )
-        )
+        print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
+                                        metrics.output_throughput))
+        print("{:<40} {:<10.2f}".format(
+            "Peak output token throughput (tok/s):",
+            metrics.max_output_tokens_per_s))
+        print("{:<40} {:<10.2f}".format("Peak concurrent requests:",
+                                        metrics.max_concurrent_requests))
     print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
                                     metrics.total_token_throughput))
 
@@ -648,6 +711,8 @@ async def limited_request_func(request_func_input, session, pbar):
             "itls": [output.itl for output in outputs],
             "generated_texts": [output.generated_text for output in outputs],
             "errors": [output.error for output in outputs],
+            "max_output_tokens_per_s": metrics.max_output_tokens_per_s,
+            "max_concurrent_requests": metrics.max_concurrent_requests,
         }
     else:
         result = {
@@ -697,8 +762,8 @@ def process_one_metric(
 
     if task_type == TaskType.GENERATION:
         process_one_metric("ttft", "TTFT", "Time to First Token")
-        process_one_metric(
-            "tpot", "TPOT", "Time per Output Token (excl. 1st token)")
+        process_one_metric("tpot", "TPOT",
+                           "Time per Output Token (excl. 1st token)")
         process_one_metric("itl", "ITL", "Inter-token Latency")
     process_one_metric("e2el", "E2EL", "End-to-end Latency")
 
@@ -714,8 +779,8 @@ def process_one_metric(
             output_len=test_output_len,
            logprobs=logprobs,
         )
-        profile_output = await request_func(
-            request_func_input=profile_input, session=session)
+        profile_output = await request_func(request_func_input=profile_input,
+                                            session=session)
         if profile_output.success:
             print("Profiler stopped")
 
@@ -851,7 +916,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--tokenizer",
         type=str,
-        help="Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
+        help=
+        "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
     )
     parser.add_argument("--use-beam-search", action="store_true")
     parser.add_argument(
@@ -982,7 +1048,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
         help="Specify the prefix of request id.",
     )
 
-
     sampling_group = parser.add_argument_group("sampling parameters")
     sampling_group.add_argument(
         "--top-p",
@@ -1047,8 +1112,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         help="The ramp-up strategy. This would be used to "
         "ramp up the request rate from initial RPS to final "
        "RPS rate (specified by --ramp-up-start-rps and "
-        "--ramp-up-end-rps.) over the duration of the benchmark."
-    )
+        "--ramp-up-end-rps.) over the duration of the benchmark.")
     parser.add_argument(
         "--ramp-up-start-rps",
        type=int,
@@ -1087,13 +1151,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
            raise ValueError(
                "When using ramp-up, do not specify --request-rate. "
                "The request rate will be controlled by ramp-up parameters. "
-                "Please remove the --request-rate argument."
-            )
+                "Please remove the --request-rate argument.")
        if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
            raise ValueError(
                "When using --ramp-up-strategy, both --ramp-up-start-rps and "
-                "--ramp-up-end-rps must be specified"
-            )
+                "--ramp-up-end-rps must be specified")
        if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
            raise ValueError("Ramp-up start and end RPS must be non-negative")
        if args.ramp_up_start_rps > args.ramp_up_end_rps:
@@ -1127,8 +1189,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
                headers[kvstring[0].strip()] = kvstring[1].strip()
            else:
                raise ValueError(
-                    "Invalid header format. Please use KEY=VALUE format."
-                )
+                    "Invalid header format. Please use KEY=VALUE format.")
 
    tokenizer = get_tokenizer(tokenizer_id,
                              tokenizer_mode=tokenizer_mode,
@@ -1215,8 +1276,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
                result_json[kvstring[0].strip()] = kvstring[1].strip()
            else:
                raise ValueError(
-                    "Invalid metadata format. Please use KEY=VALUE format."
-                )
+                    "Invalid metadata format. Please use KEY=VALUE format.")
 
    # Traffic
    result_json["request_rate"] = (args.request_rate if args.request_rate