65 changes: 62 additions & 3 deletions benchmark/mmmu/bench_sglang.py
@@ -10,8 +10,14 @@
"""

import argparse
import asyncio
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List

import aiohttp
import openai
from data_utils import save_json
from eval_utils import (
@@ -25,8 +31,41 @@

from sglang.test.test_utils import add_common_sglang_args_and_parse

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)

def eval_mmmu(args):

@dataclass
class RequestFuncOutput:
generated_text: List[str] = field(default_factory=list)
prompt_len: List[int] = field(default_factory=list)
output_len: List[int] = field(default_factory=list)
latency: List[float] = field(default_factory=list)
ttft: List[float] = field(default_factory=list)
itl: List[float] = field(default_factory=list) # List of inter-token latencies

success: bool = False
error: str = ""


async def async_request_profile(api_url: str) -> RequestFuncOutput:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
output = RequestFuncOutput()
try:
async with session.post(url=api_url) as response:
if response.status == 200:
output.success = True
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))

return output


async def eval_mmmu(args):
eval_args = EvalArgs.from_cli_args(args)

out_samples = dict()
@@ -38,9 +77,22 @@ def eval_mmmu(args):
answer_dict = {}

# had to use an openai server, since SglImage doesn't support image data
client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")
base_url = f"http://127.0.0.1:{args.port}"
client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")

start = time.time()

if args.profile:
print("Starting profiler...")
profile_output = await async_request_profile(
api_url=f"{base_url}/start_profile"
)
if profile_output.success:
print("Profiler started")

if args.profile:
samples = samples[: args.profile_number]

for i, sample in enumerate(tqdm(samples)):
prompt = sample["final_input_prompt"]
prefix = prompt.split("<")[0]
@@ -49,6 +101,7 @@
assert image is not None
image_path = sample["image_path"]
# TODO: batch

response = client.chat.completions.create(
model="default",
messages=[
@@ -77,6 +130,12 @@
response = response.choices[0].message.content
process_result(response, sample, answer_dict, out_samples)

if args.profile:
print("Stopping profiler...")
profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile")
if profile_output.success:
print("Profiler stopped")

print(f"Benchmark time: {time.time() - start}")

args.output_path = f"./val_sglang.json"
@@ -89,4 +148,4 @@ def eval_mmmu(args):
EvalArgs.add_cli_args(parser)
args = add_common_sglang_args_and_parse(parser)
args = parser.parse_args()
eval_mmmu(args)
asyncio.run(eval_mmmu(args))
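As a quick sanity check of the new profiling flow, here is a minimal sketch of how the added /start_profile and /stop_profile endpoints can be driven around a workload. It reuses the async_request_profile helper introduced above; the port and the placement of the snippet inside bench_sglang.py are assumptions for illustration, not part of the PR.

# Minimal sketch: drive the profiler endpoints around a dummy workload.
# Assumes a running SGLang server at 127.0.0.1:30000 (hypothetical port)
# and that this code lives alongside async_request_profile in bench_sglang.py.
import asyncio

async def profile_roundtrip(base_url: str = "http://127.0.0.1:30000"):
    start = await async_request_profile(api_url=f"{base_url}/start_profile")
    if not start.success:
        print(f"Could not start profiler: {start.error}")
        return
    print("Profiler started")
    # ... issue a handful of benchmark requests here ...
    stop = await async_request_profile(api_url=f"{base_url}/stop_profile")
    print("Profiler stopped" if stop.success else f"Stop failed: {stop.error}")

if __name__ == "__main__":
    asyncio.run(profile_roundtrip())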
8 changes: 8 additions & 0 deletions benchmark/mmmu/eval_utils.py
@@ -33,6 +33,8 @@ class EvalArgs:
prompt_format_file: str = "prompt_format.yaml"
dataset_path: str = "MMMU/MMMU"
extra_request_body: Optional[str] = None
profile: bool = False
profile_number: int = 5

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -65,6 +67,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
        help="Append given JSON object to the request payload. You can use this to specify "
        "additional generate params like sampling params.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Profile a short run of the MMMU benchmark with the server-side profiler.",
    )
    parser.add_argument(
        "--profile-number",
        type=int,
        default=EvalArgs.profile_number,
        help="Number of samples to evaluate when --profile is set.",
    )

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
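For reference, a rough sketch of how the two new flags flow from argparse into EvalArgs. The body of from_cli_args is truncated in the diff, so the assumption here is that it copies attributes with matching names from the parsed namespace, which is the usual dataclass-from-argparse pattern.

# Rough sketch (assumes from_cli_args copies fields with matching names;
# its body is not shown in the diff above).
import argparse
from eval_utils import EvalArgs

parser = argparse.ArgumentParser()
EvalArgs.add_cli_args(parser)
args = parser.parse_args(["--profile", "--profile-number", "3"])
eval_args = EvalArgs.from_cli_args(args)
print(eval_args.profile, eval_args.profile_number)  # expected: True 3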