Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions benchmarks/mmmu/benchmark_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch
from eval_utils import (
add_common_benchmark_args,
get_message,
load_benchmark_config,
load_benchmark_dataset,
run_benchmark,
)
from transformers import AutoModelForImageTextToText, AutoProcessor, set_seed

from vllm.utils import FlexibleArgumentParser


def load_model_and_processor(model_name: str):
    """Load a HuggingFace Vision-Language model and its processor.

    Args:
        model_name: HuggingFace Hub model id or local path.

    Returns:
        Tuple of ``(model, processor)``; the model is in eval mode on CUDA.

    Raises:
        ValueError: if no supported auto class can load the model. The
            message includes each per-class failure for diagnosis.
    """
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    model = None
    errors: list[str] = []
    # Try each candidate auto class in order; keep the first that works.
    for auto_class in [AutoModelForImageTextToText]:
        try:
            model = auto_class.from_pretrained(
                model_name, torch_dtype="auto", trust_remote_code=True
            )
            print(f"Successfully loaded model with {auto_class.__name__}")
            break
        except Exception as e:
            # Record the reason instead of silently discarding it, then
            # fall through to the next candidate class.
            errors.append(f"{auto_class.__name__}: {e}")

    if model is None:
        raise ValueError(
            f"Could not load model {model_name} with any available auto class: "
            + "; ".join(errors)
        )

    model = model.eval().cuda()

    return model, processor


def generate_response(
    model,
    processor,
    prompt: str,
    image,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: Optional[int],
    do_sample: bool,
    seed: int,
) -> str:
    """Run one generation round with a HuggingFace Vision-Language model.

    Renders the prompt (and optional image) through the processor's chat
    template, generates up to ``max_tokens`` new tokens, and returns the
    decoded continuation with the prompt tokens stripped and whitespace
    trimmed.
    """
    # Reseed every call so individual samples are reproducible.
    set_seed(seed)

    chat = get_message(prompt, image)

    # Render the conversation into the model's prompt format.
    rendered = processor.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    image_batch = None if image is None else [image]
    model_inputs = processor(
        text=[rendered],
        images=image_batch,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )

    # Slice off the prompt tokens so only newly generated text is decoded.
    continuations = [
        seq[len(src):]
        for src, seq in zip(model_inputs.input_ids, output_ids)
    ]

    decoded = processor.batch_decode(
        continuations,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return decoded[0].strip()


def hf_generate_func(model, processor, generation_params):
    """Build a batch-interface generation callable for HuggingFace VL models.

    The returned function matches the common benchmark interface: it takes
    a list of prompts (and an optional parallel list of images) and returns
    a list of response strings, generating one sample at a time.
    """

    def generate(prompts: list[str], images: Optional[list] = None) -> list[str]:
        """Generate one response per prompt using the HF VL model."""
        # Text-only runs get a None image per prompt.
        paired_images = [None] * len(prompts) if images is None else images
        return [
            generate_response(
                model,
                processor,
                prompt,
                image,
                max_tokens=generation_params.max_tokens,
                temperature=generation_params.temperature,
                top_p=generation_params.top_p,
                top_k=generation_params.top_k,
                do_sample=generation_params.do_sample,
                seed=generation_params.seed,
            )
            for prompt, image in zip(prompts, paired_images)
        ]

    return generate


def main(args):
    """Run the MMMU benchmark for a HuggingFace VL model described by *args*."""
    print(f"Loading model from {args.model}...")
    model, processor = load_model_and_processor(args.model)

    # Fall back to the default config file when the arg is absent.
    config = load_benchmark_config(
        getattr(args, "config_path", "eval_config.yaml")
    )

    samples = load_benchmark_dataset(
        split=args.split, subject=args.subject, max_samples=args.max_samples
    )

    generate_func = hf_generate_func(model, processor, args)

    # Metadata recorded alongside the results.
    model_info = {
        "model": args.model,
        "split": args.split,
        "subject": args.subject,
        "max_samples": args.max_samples,
    }

    return run_benchmark(
        samples=samples,
        config=config,
        args=args,
        generate_func=generate_func,
        batch_size=1,  # HF path generates one sample at a time
        subject=args.subject,
        output_path=args.output_path,
        model_info=model_info,
    )


def invoke_main() -> None:
    """Parse CLI arguments and launch the HF MMMU benchmark."""
    parser = FlexibleArgumentParser(
        description="Benchmark HuggingFace models on MMMU dataset from HuggingFace Hub"
    )
    # Shared benchmark flags (model, split, sampling params, output path, ...).
    parser = add_common_benchmark_args(parser, framework="hf")

    main(parser.parse_args())


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
164 changes: 164 additions & 0 deletions benchmarks/mmmu/benchmark_vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

from eval_utils import (
    add_common_benchmark_args,
    get_message,
    load_benchmark_config,
    load_benchmark_dataset,
    run_benchmark,
)
from transformers import AutoTokenizer

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def main(args: dict):
    """Run the MMMU benchmark using vLLM offline inference.

    Splits the parsed CLI arguments into sampling/benchmark options and
    engine options, builds the ``LLM``, and delegates to ``run_benchmark``.

    Args:
        args: All parsed CLI arguments as a dict. Benchmark- and
            sampling-specific keys are popped out; everything remaining is
            forwarded to the ``LLM`` constructor.

    Returns:
        The results produced by ``run_benchmark``.
    """
    # Read (without popping) args that both the benchmark and the engine use.
    seed = args.get("seed")
    model_name = args.get("model")

    # Pop sampling arguments (not accepted by the LLM constructor).
    max_tokens = args.pop("max_tokens")
    temperature = args.pop("temperature")
    top_p = args.pop("top_p")
    top_k = args.pop("top_k")

    # Pop benchmark-specific arguments.
    split = args.pop("split")
    subject = args.pop("subject")
    max_samples = args.pop("max_samples")
    output_path = args.pop("output_path")
    config_path = args.pop("config_path")
    batch_size = args.pop("batch_size")

    # Create an LLM with the remaining (engine) args.
    print("Loading vLLM model...")
    args["disable_mm_preprocessor_cache"] = True
    llm = LLM(**args)

    # Load the tokenizer separately for chat-template formatting.
    print(f"Loading tokenizer from {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Start from the model's default sampling params and override only the
    # values supplied on the command line.
    sampling_params = llm.get_default_sampling_params()
    if max_tokens is not None:
        sampling_params.max_tokens = max_tokens
    if temperature is not None:
        sampling_params.temperature = temperature
    if top_p is not None:
        sampling_params.top_p = top_p
    if top_k is not None:
        sampling_params.top_k = top_k
    if seed is not None:
        sampling_params.seed = seed

    # Minimal args namespace expected by the common benchmark function.
    class Args:
        def __init__(self):
            self.seed = seed
            self.max_tokens = max_tokens
            self.temperature = temperature
            self.top_p = top_p
            self.top_k = top_k

    benchmark_args = Args()

    # Load evaluation config.
    config = load_benchmark_config(config_path)

    # Load dataset.
    samples = load_benchmark_dataset(
        split=split, subject=subject, max_samples=max_samples
    )

    # Metadata recorded alongside the results.
    model_info = {
        "model": model_name,
        "split": split,
        "subject": subject,
        "max_samples": max_samples,
        "batch_size": batch_size,
    }

    # Generation function matching the HF benchmark interface.
    def generate_with_params(
        prompts: list[str], images: Optional[list] = None
    ) -> list[str]:
        """Generate responses for prompts with associated images.

        Args:
            prompts: List of prompt strings.
            images: Optional list of image data (``None`` entries or a
                ``None`` list mean text-only).

        Returns:
            List of response strings, one per prompt.
        """
        # Prepare inputs for vLLM batch inference.
        inputs = []
        if images is None:
            images = [None] * len(prompts)

        for prompt, image in zip(prompts, images):
            messages = get_message(prompt, image)
            try:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            except Exception as e:
                # Fall back to the raw prompt when the tokenizer has no
                # usable chat template. (Fixed: the original warning was
                # split with a backslash continuation inside the f-string,
                # which embedded source indentation into the message.)
                print(
                    "Warning: Failed to apply chat template, "
                    f"using original prompt: {e}"
                )
                formatted_prompt = prompt

            input_data = {"prompt": formatted_prompt}
            if image is not None:
                input_data["multi_modal_data"] = {"image": image}
            inputs.append(input_data)

        # Use the pre-configured sampling_params for the whole batch.
        outputs = llm.generate(inputs, sampling_params, use_tqdm=False)
        return [output.outputs[0].text.strip() for output in outputs]

    # Run benchmark using common logic.
    results = run_benchmark(
        samples=samples,
        config=config,
        args=benchmark_args,
        generate_func=generate_with_params,
        batch_size=batch_size,
        subject=subject,
        output_path=output_path,
        model_info=model_info,
    )

    return results


def create_parser():
    """Build the CLI parser: vLLM engine args plus common benchmark args."""
    base = FlexibleArgumentParser(
        description="Benchmark vLLM models on MMMU dataset using offline inference",
        conflict_handler="resolve",
    )

    # Engine args go in first so the benchmark args added afterwards can
    # override any conflicting vLLM defaults.
    EngineArgs.add_cli_args(base)

    return add_common_benchmark_args(base, framework="vllm")


def invoke_main() -> None:
    """Entry point: parse CLI arguments into a dict and run the benchmark."""
    cli_args: dict = vars(create_parser().parse_args())
    main(cli_args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
Loading
Loading