From 5f75b48a80494368f71d4a07f60f14a88b5f8a63 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Wed, 30 Apr 2025 03:44:01 +0000 Subject: [PATCH] support vlm benchmark profile --- benchmark/mmmu/bench_sglang.py | 65 ++++++++++++++++++++++++++++++++-- benchmark/mmmu/eval_utils.py | 8 +++++ 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py index 3f91678ac61..55a7b1eaa2a 100644 --- a/benchmark/mmmu/bench_sglang.py +++ b/benchmark/mmmu/bench_sglang.py @@ -10,8 +10,14 @@ """ import argparse +import asyncio +import sys import time +import traceback +from dataclasses import dataclass, field +from typing import List +import aiohttp import openai from data_utils import save_json from eval_utils import ( @@ -25,8 +31,41 @@ from sglang.test.test_utils import add_common_sglang_args_and_parse +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60) -def eval_mmmu(args): + +@dataclass +class RequestFuncOutput: + generated_text: List[str] = field(default_factory=list) + prompt_len: List[int] = field(default_factory=list) + output_len: List[int] = field(default_factory=list) + latency: List[float] = field(default_factory=list) + ttft: List[float] = field(default_factory=list) + itl: List[float] = field(default_factory=list) # List of inter-token latencies + + success: bool = False + error: str = "" + + +async def async_request_profile(api_url: str) -> RequestFuncOutput: + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + output = RequestFuncOutput() + try: + async with session.post(url=api_url) as response: + if response.status == 200: + output.success = True + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + return output + + +async def eval_mmmu(args): eval_args = EvalArgs.from_cli_args(args) out_samples = dict() @@ -38,9 +77,22 @@ def eval_mmmu(args): answer_dict = {} # had to use an openai server, since SglImage doesn't support image data - client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1") + base_url = f"http://127.0.0.1:{args.port}" + client = openai.Client(api_key="sk", base_url=f"{base_url}/v1") start = time.time() + + if args.profile: + print("Starting profiler...") + profile_output = await async_request_profile( + api_url=f"{base_url}/start_profile" + ) + if profile_output.success: + print("Profiler started") + + if args.profile: + samples = samples[: args.profile_number] + for i, sample in enumerate(tqdm(samples)): prompt = sample["final_input_prompt"] prefix = prompt.split("<")[0] @@ -49,6 +101,7 @@ def eval_mmmu(args): assert image is not None image_path = sample["image_path"] # TODO: batch + response = client.chat.completions.create( model="default", messages=[ @@ -77,6 +130,12 @@ def eval_mmmu(args): response = response.choices[0].message.content process_result(response, sample, answer_dict, out_samples) + if args.profile: + print("Stopping profiler...") + profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile") + if profile_output.success: + print("Profiler stopped") + print(f"Benchmark time: {time.time() - start}") args.output_path = f"./val_sglang.json" @@ -89,4 +148,4 @@ def eval_mmmu(args): EvalArgs.add_cli_args(parser) args = add_common_sglang_args_and_parse(parser) args = parser.parse_args() - eval_mmmu(args) + asyncio.run(eval_mmmu(args)) diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py index 59e2c49308a..1a7db250e08 100644 --- a/benchmark/mmmu/eval_utils.py +++ b/benchmark/mmmu/eval_utils.py @@ -33,6 +33,8 @@ class EvalArgs: prompt_format_file: str = "prompt_format.yaml" dataset_path: str = "MMMU/MMMU" extra_request_body: Optional[str] = None + profile: bool = False + profile_number: int = 5 @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -65,6 +67,12 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Append given JSON object to the request payload. You can use this to specify" "additional generate params like sampling params.", ) + parser.add_argument( + "--profile", action="store_true", help="enable mmmu profile" + ) + parser.add_argument( + "--profile-number", type=int, default=EvalArgs.profile_number + ) @classmethod def from_cli_args(cls, args: argparse.Namespace):