diff --git a/benchmarks/mmmu/benchmark_hf.py b/benchmarks/mmmu/benchmark_hf.py
new file mode 100644
index 000000000000..7c43c4e12c8d
--- /dev/null
+++ b/benchmarks/mmmu/benchmark_hf.py
@@ -0,0 +1,183 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import torch
+from eval_utils import (
+ add_common_benchmark_args,
+ get_message,
+ load_benchmark_config,
+ load_benchmark_dataset,
+ run_benchmark,
+)
+from transformers import AutoModelForImageTextToText, AutoProcessor, set_seed
+
+from vllm.utils import FlexibleArgumentParser
+
+
def load_model_and_processor(model_name: str):
    """Load a HuggingFace vision-language model and its processor.

    Args:
        model_name: HuggingFace Hub model id or local path.

    Returns:
        Tuple of (model, processor); the model is moved to CUDA in eval mode.

    Raises:
        ValueError: if no candidate auto class can load the model; the last
            underlying load error is chained as the cause.
    """
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    model = None
    last_error: Optional[Exception] = None
    # Try each candidate auto class in order. Keep the last failure so the
    # final error explains *why* loading failed instead of swallowing it.
    for auto_class in [AutoModelForImageTextToText]:
        try:
            model = auto_class.from_pretrained(
                model_name, torch_dtype="auto", trust_remote_code=True
            )
            print(f"Successfully loaded model with {auto_class.__name__}")
            break
        except Exception as e:
            last_error = e
            continue

    if model is None:
        raise ValueError(
            f"Could not load model {model_name} with any available auto class"
        ) from last_error

    model = model.eval().cuda()

    return model, processor
+
+
def generate_response(
    model,
    processor,
    prompt: str,
    image,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: Optional[int],
    do_sample: bool,
    seed: int,
) -> str:
    """Generate one response with a HuggingFace vision-language model.

    Args:
        model: Loaded HF model (already on its target device).
        processor: Matching HF processor (tokenizer + image processor).
        prompt: User prompt text.
        image: Image object, or None for a text-only prompt.
        max_tokens: Maximum number of NEW tokens to generate.
        temperature: Sampling temperature, passed to `model.generate`.
        top_p: Nucleus sampling parameter, passed to `model.generate`.
        top_k: Top-k sampling parameter, passed to `model.generate`.
        do_sample: Whether to sample (vs. greedy decode).
        seed: Seed applied before generation for reproducibility.

    Returns:
        The decoded, whitespace-stripped model response.
    """
    # Set seed for reproducibility
    set_seed(seed)

    messages = get_message(prompt, image)

    # Apply chat template
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Process inputs
    inputs = processor(
        text=[text],
        images=[image] if image is not None else None,
        return_tensors="pt",
        padding=True,
    )
    inputs = inputs.to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )

    # Extract generated tokens (excluding input tokens): generate() returns
    # prompt + continuation, so slice off each prompt's token length.
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    response = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    return response.strip()
+
+
def hf_generate_func(model, processor, generation_params):
    """Build a generation callable for HuggingFace VL models that matches
    the common benchmark interface (prompts + optional images -> responses)."""

    def generate(prompts: list[str], images: Optional[list] = None) -> list[str]:
        """Run generate_response once per (prompt, image) pair, in order."""
        paired_images = [None] * len(prompts) if images is None else images
        return [
            generate_response(
                model,
                processor,
                prompt,
                image,
                max_tokens=generation_params.max_tokens,
                temperature=generation_params.temperature,
                top_p=generation_params.top_p,
                top_k=generation_params.top_k,
                do_sample=generation_params.do_sample,
                seed=generation_params.seed,
            )
            for prompt, image in zip(prompts, paired_images)
        ]

    return generate
+
+
def main(args):
    """Run the MMMU benchmark against a HuggingFace model.

    Loads the model/processor, the eval config and the dataset, then
    delegates sample processing and scoring to run_benchmark.

    Args:
        args: Parsed CLI namespace (model, split, subject, max_samples,
            output_path, and the generation parameters).

    Returns:
        The results dictionary produced by run_benchmark.
    """
    # Load model and processor
    print(f"Loading model from {args.model}...")
    model, processor = load_model_and_processor(args.model)

    # Load evaluation config
    config = load_benchmark_config(
        args.config_path if hasattr(args, "config_path") else "eval_config.yaml"
    )

    # Load dataset
    samples = load_benchmark_dataset(
        split=args.split, subject=args.subject, max_samples=args.max_samples
    )

    # Create generation function; args doubles as the generation-params holder
    generate_func = hf_generate_func(model, processor, args)

    # Model info for saving
    model_info = {
        "model": args.model,
        "split": args.split,
        "subject": args.subject,
        "max_samples": args.max_samples,
    }

    # Run benchmark using common logic
    results = run_benchmark(
        samples=samples,
        config=config,
        args=args,
        generate_func=generate_func,
        batch_size=1,  # HF processes one at a time
        subject=args.subject,
        output_path=args.output_path,
        model_info=model_info,
    )

    return results
+
+
def invoke_main() -> None:
    """CLI entry point: build the argument parser and run the benchmark."""
    parser = FlexibleArgumentParser(
        description="Benchmark HuggingFace models on MMMU dataset from HuggingFace Hub"
    )

    # Add common benchmark arguments
    parser = add_common_benchmark_args(parser, framework="hf")

    args = parser.parse_args()
    main(args)
+
+
+if __name__ == "__main__":
+ invoke_main() # pragma: no cover
diff --git a/benchmarks/mmmu/benchmark_vllm.py b/benchmarks/mmmu/benchmark_vllm.py
new file mode 100644
index 000000000000..02621864e559
--- /dev/null
+++ b/benchmarks/mmmu/benchmark_vllm.py
@@ -0,0 +1,164 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from eval_utils import (
+ add_common_benchmark_args,
+ get_message,
+ load_benchmark_config,
+ load_benchmark_dataset,
+ run_benchmark,
+)
+from transformers import AutoTokenizer
+
+from vllm import LLM, EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+
def main(args: dict):
    """Run the MMMU benchmark with vLLM offline inference.

    *args* is the parsed CLI namespace as a dict: sampling and benchmark
    keys are popped out first, and everything that remains is forwarded
    verbatim to the vLLM ``LLM(**args)`` constructor as engine arguments.

    Returns:
        The results dictionary produced by run_benchmark.
    """
    # get common args (peek, don't pop: "model" must stay in the engine args)
    seed = args.get("seed")
    model_name = args.get("model")

    # Pop sampling arguments
    max_tokens = args.pop("max_tokens")
    temperature = args.pop("temperature")
    top_p = args.pop("top_p")
    top_k = args.pop("top_k")

    # Pop benchmark specific arguments
    split = args.pop("split")
    subject = args.pop("subject")
    max_samples = args.pop("max_samples")
    output_path = args.pop("output_path")
    config_path = args.pop("config_path")
    batch_size = args.pop("batch_size")

    # Create an LLM with remaining args
    print("Loading vLLM model...")
    args["disable_mm_preprocessor_cache"] = True
    llm = LLM(**args)

    # Load tokenizer for chat template
    print(f"Loading tokenizer from {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Create sampling params using the LLM instance; only override the
    # model defaults for values the user actually supplied.
    sampling_params = llm.get_default_sampling_params()
    if max_tokens is not None:
        sampling_params.max_tokens = max_tokens
    if temperature is not None:
        sampling_params.temperature = temperature
    if top_p is not None:
        sampling_params.top_p = top_p
    if top_k is not None:
        sampling_params.top_k = top_k
    if seed is not None:
        sampling_params.seed = seed

    # Minimal attribute container mirroring the HF argparse namespace, so
    # run_benchmark can read generation params the same way for both backends.
    class Args:
        def __init__(self):
            self.seed = seed
            self.max_tokens = max_tokens
            self.temperature = temperature
            self.top_p = top_p
            self.top_k = top_k

    benchmark_args = Args()

    # Load evaluation config
    config = load_benchmark_config(config_path)

    # Load dataset
    samples = load_benchmark_dataset(
        split=split, subject=subject, max_samples=max_samples
    )

    # Model info for saving
    model_info = {
        "model": model_name,
        "split": split,
        "subject": subject,
        "max_samples": max_samples,
        "batch_size": batch_size,
    }

    # Create a generation function that matches the HF interface
    def generate_with_params(
        prompts: list[str], images: list | None = None
    ) -> list[str]:
        """
        Generate responses for prompts with associated images.
        Args:
            prompts: List of prompt strings
            images: List of image data (can be None for text-only)
        Returns:
            List of response strings
        """
        # Prepare inputs for vLLM batch inference
        inputs = []
        if images is None:
            images = [None] * len(prompts)

        for prompt, image in zip(prompts, images):
            messages = get_message(prompt, image)
            try:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            except Exception as e:
                print(
                    f"Warning: Failed to apply chat template,\
                    using original prompt: {e}"
                )
                formatted_prompt = prompt

            input_data = {"prompt": formatted_prompt}
            if image is not None:
                input_data["multi_modal_data"] = {"image": image}
            inputs.append(input_data)

        # Use our pre-configured sampling_params
        outputs = llm.generate(inputs, sampling_params, use_tqdm=False)
        responses = []
        for output in outputs:
            response = output.outputs[0].text.strip()
            responses.append(response)
        return responses

    # Run benchmark
    results = run_benchmark(
        samples=samples,
        config=config,
        args=benchmark_args,
        generate_func=generate_with_params,
        batch_size=batch_size,
        subject=subject,
        output_path=output_path,
        model_info=model_info,
    )

    return results
+
+
def create_parser():
    """Build the CLI parser: vLLM engine args plus common benchmark args.

    conflict_handler="resolve" lets the benchmark arguments override any
    engine argument registered under the same flag.
    """
    parser = FlexibleArgumentParser(
        description="Benchmark vLLM models on MMMU dataset using offline inference",
        conflict_handler="resolve",
    )

    # Add engine args first (these provide base vLLM functionality)
    EngineArgs.add_cli_args(parser)

    # Add common benchmark arguments (these will override conflicting vLLM defaults)
    parser = add_common_benchmark_args(parser, framework="vllm")

    return parser
+
+
def invoke_main() -> None:
    """CLI entry point: parse args into a dict (main pops from it) and run."""
    parser = create_parser()
    args: dict = vars(parser.parse_args())
    main(args)
+
+
+if __name__ == "__main__":
+ invoke_main() # pragma: no cover
diff --git a/benchmarks/mmmu/data_utils.py b/benchmarks/mmmu/data_utils.py
new file mode 100644
index 000000000000..38378f30b4e9
--- /dev/null
+++ b/benchmarks/mmmu/data_utils.py
@@ -0,0 +1,234 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Adapted from
+# https://github.com/MMMU-Benchmark/MMMU
+
+"""Utils for data load, save, and process (e.g., prompt construction)"""
+
+import ast
+import json
+import os
+import re
+from typing import Optional
+
+import yaml
+from datasets import load_dataset
+
# Mapping of MMMU top-level domains to their constituent subject names,
# as defined by the upstream MMMU benchmark.
DOMAIN_CAT2SUB_CAT = {
    "Art and Design": ["Art", "Art_Theory", "Design", "Music"],
    "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
    "Science": [
        "Biology",
        "Chemistry",
        "Geography",
        "Math",
        "Physics",
    ],
    "Health and Medicine": [
        "Basic_Medical_Science",
        "Clinical_Medicine",
        "Diagnostics_and_Laboratory_Medicine",
        "Pharmacy",
        "Public_Health",
    ],
    "Humanities and Social Science": [
        "History",
        "Literature",
        "Sociology",
        "Psychology",
    ],
    "Tech and Engineering": [
        "Agriculture",
        "Architecture_and_Engineering",
        "Computer_Science",
        "Electronics",
        "Energy_and_Power",
        "Materials",
        "Mechanical_Engineering",
    ],
}


# Short subject codes mapped to the full MMMU subject (dataset config) names.
CAT_SHORT2LONG = {
    "acc": "Accounting",
    "agri": "Agriculture",
    "arch": "Architecture_and_Engineering",
    "art": "Art",
    "art_theory": "Art_Theory",
    "bas_med": "Basic_Medical_Science",
    "bio": "Biology",
    "chem": "Chemistry",
    "cli_med": "Clinical_Medicine",
    "cs": "Computer_Science",
    "design": "Design",
    "diag_med": "Diagnostics_and_Laboratory_Medicine",
    "econ": "Economics",
    "elec": "Electronics",
    "ep": "Energy_and_Power",
    "fin": "Finance",
    "geo": "Geography",
    "his": "History",
    "liter": "Literature",
    "manage": "Manage",
    "mark": "Marketing",
    "mate": "Materials",
    "math": "Math",
    "mech": "Mechanical_Engineering",
    "music": "Music",
    "phar": "Pharmacy",
    "phys": "Physics",
    "psy": "Psychology",
    "pub_health": "Public_Health",
    "socio": "Sociology",
}
+
+
def load_mmmu_dataset(subset: str = "validation", subject: Optional[str] = None):
    """Fetch the MMMU dataset from the HuggingFace Hub.

    Args:
        subset: Dataset split to download (e.g. "validation").
        subject: Single subject name, or None to load every subject.

    Returns:
        Dict mapping subject name -> loaded dataset for that split.
    """
    wanted = set(CAT_SHORT2LONG.values()) if subject is None else {subject}
    return {subj: load_dataset("MMMU/MMMU", subj, split=subset) for subj in wanted}
+
+
def get_multi_choice_info(options):
    """
    Map a list of option strings to letter choices.

    Returns:
        Tuple (index2ans, all_choices): index2ans maps 'A', 'B', ... to the
        option text, all_choices is the list of letters actually used.
    """
    letters = [chr(ord("A") + pos) for pos in range(len(options))]
    index2ans = dict(zip(letters, options))
    return index2ans, letters
+
+
def load_yaml(file_path):
    """Parse a YAML file and return its contents.

    Args:
        file_path: Path to the YAML file.

    Returns:
        The parsed YAML document (typically a dict).

    Raises:
        yaml.YAMLError: re-raised after printing. (Previously execution fell
            through to ``return yaml_dict`` after the except block, raising
            UnboundLocalError instead of surfacing the parse error.)
    """
    with open(file_path) as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
+
+
def parse_img_path(text):
    """Extract image paths embedded in MMMU option text as <img='...'> tags.

    The regex literal was lost in transit (the pattern between the quotes
    was stripped, leaving a syntax error); this restores the upstream MMMU
    pattern, returning every captured path in order of appearance.
    """
    return re.findall(r"<img='(.*?)'>", text)
+
+
def process_single_sample(data):
    """Normalize one raw MMMU record into the benchmark's sample schema.

    Records whose options embed more than one <img=...> path are returned
    with image=None (upstream MMMU treats these as unrenderable for a
    single-image prompt); otherwise image_1 is carried through.
    """
    option_img_paths = []
    for option in data["options"]:
        option_img_paths.extend(parse_img_path(option))

    # Multiple images inside the options cannot be shown as one image.
    chosen_image = None if len(option_img_paths) > 1 else data["image_1"]

    return {
        "id": data["id"],
        "question": data["question"],
        "options": data["options"],
        "answer": data["answer"],
        "image": chosen_image,
        "question_type": data["question_type"],
    }
+
+
+# DATA SAVING
def save_json(filename, ds):
    """Write *ds* to *filename* as pretty-printed (indent=4) JSON."""
    with open(filename, "w") as out_file:
        json.dump(ds, out_file, indent=4)
+
+
def save_jsonl(filename, data):
    """
    Write a mapping of image paths to captions as JSON Lines.

    Each output line is one JSON object {basename: caption}, where basename
    is the file-name portion of the image path.

    Args:
        filename (str): The path to the file where the data should be saved.
        data (dict): Mapping of image path -> caption.
    """
    with open(filename, "w", encoding="utf-8") as out:
        out.writelines(
            json.dumps({os.path.basename(img_path): caption}, ensure_ascii=False)
            + "\n"
            for img_path, caption in data.items()
        )
+
+
def save_args(args, path_dir):
    """Write each argument name/value pair to ``setting.txt`` in *path_dir*.

    Uses os.path.join so the file lands inside *path_dir* even when the
    caller omits the trailing separator (the old string concatenation
    produced e.g. 'outsetting.txt' for path_dir='out').

    Args:
        args: Namespace-like object; its ``__dict__`` is dumped.
        path_dir: Directory that receives setting.txt.
    """
    args_dict = args.__dict__
    with open(os.path.join(path_dir, "setting.txt"), "w") as f:
        f.writelines("------------------ start ------------------" + "\n")
        for arg_name, value in args_dict.items():
            f.writelines(arg_name + " : " + str(value) + "\n")
        f.writelines("------------------- end -------------------")
+
+
+# DATA PROCESSING
def construct_prompt(sample, config):
    """Build the final input prompt (plus answer metadata) for one sample.

    Multiple-choice options are lettered (A), (B), ... and substituted into
    config["multi_choice_example_format"]; open questions use
    config["short_ans_example_format"]. A non-empty
    config["task_instructions"] is prepended to the prompt.

    Returns:
        Dict with "final_input_prompt", "empty_prompt", "gt_content" (and,
        for multiple choice, "index2ans"/"all_choices"/"correct_choice"),
        merged with every original sample field.
    """
    question = sample["question"]
    options = ast.literal_eval(sample["options"])

    if sample["question_type"] == "multiple-choice":
        letters = [chr(ord("A") + pos) for pos in range(len(options))]
        example = "".join(
            f"({letter}) {option}\n" for letter, option in zip(letters, options)
        )
        empty_prompt = config["multi_choice_example_format"].format(question, example)
        res_dict = {
            "index2ans": dict(zip(letters, options)),
            "correct_choice": sample["answer"],
            "all_choices": letters,
            "empty_prompt": empty_prompt,
            # Answer letter -> option text, for content-level grading.
            "gt_content": options[ord(sample["answer"].upper()) - ord("A")],
        }
    else:
        empty_prompt = config["short_ans_example_format"].format(question)
        res_dict = {
            "empty_prompt": empty_prompt,
            "gt_content": sample["answer"],
        }

    if config["task_instructions"]:
        res_dict["final_input_prompt"] = (
            config["task_instructions"].strip() + "\n\n" + empty_prompt
        )
    else:
        res_dict["final_input_prompt"] = empty_prompt

    res_dict.update(sample)
    return res_dict
diff --git a/benchmarks/mmmu/eval_config.yaml b/benchmarks/mmmu/eval_config.yaml
new file mode 100644
index 000000000000..89850e5fd549
--- /dev/null
+++ b/benchmarks/mmmu/eval_config.yaml
@@ -0,0 +1,18 @@
+# Adapted from
+# https://github.com/MMMU-Benchmark/MMMU
+
+task_instructions:
+- ""
+multi_choice_example_format:
+- "{}
+
+{}
+
+Answer with the option's letter from the given choices directly."
+
+short_ans_example_format:
+- "{}
+
+Answer the question using a single word or phrase."
+temperature:
+- 0
\ No newline at end of file
diff --git a/benchmarks/mmmu/eval_utils.py b/benchmarks/mmmu/eval_utils.py
new file mode 100644
index 000000000000..38824684823b
--- /dev/null
+++ b/benchmarks/mmmu/eval_utils.py
@@ -0,0 +1,691 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Adapted from
+# https://github.com/MMMU-Benchmark/MMMU
+
+"""Response Parsing and Evaluation for various models"""
+
+import gc
+import json
+import os
+import random
+import re
+from collections import defaultdict
+from typing import Any, Callable
+
+import numpy as np
+from data_utils import (
+ construct_prompt,
+ load_mmmu_dataset,
+ load_yaml,
+ process_single_sample,
+)
+from tqdm import tqdm
+
+
+# ----------- Default Configuration -------------
class BenchmarkDefaults:
    """Default values for benchmark parameters (shared by HF and vLLM runners)."""

    # Dataset parameters
    SPLIT = "validation"
    SUBJECT = None
    MAX_SAMPLES = -1
    CONFIG_PATH = "eval_config.yaml"

    # Generation parameters
    SEED = 42
    TEMPERATURE = 0.01
    TOP_P = 0.9
    TOP_K = None
    MAX_TOKENS = 512
    DO_SAMPLE = True

    # Benchmark parameters
    BATCH_SIZE = 1
    OUTPUT_PATH_HF = "benchmark_results_hf.json"
    OUTPUT_PATH_VLLM = "benchmark_results_vllm.json"

    # vLLM specific defaults
    MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"

    @classmethod
    def get_common_args_dict(cls):
        """Return the defaults shared by every framework as a dict."""
        keys = (
            "split",
            "subject",
            "max_samples",
            "config_path",
            "seed",
            "temperature",
            "top_p",
            "top_k",
            "max_tokens",
            "do_sample",
            "batch_size",
        )
        # Each lower-case arg name maps to an UPPER_SNAKE class constant.
        return {key: getattr(cls, key.upper()) for key in keys}

    @classmethod
    def get_hf_args_dict(cls):
        """Common defaults plus the HuggingFace output path."""
        return {**cls.get_common_args_dict(), "output_path": cls.OUTPUT_PATH_HF}

    @classmethod
    def get_vllm_args_dict(cls):
        """Common defaults plus vLLM-specific model and output path."""
        return {
            **cls.get_common_args_dict(),
            "model": cls.MODEL,
            "top_p": cls.TOP_P,
            "output_path": cls.OUTPUT_PATH_VLLM,
        }
+
+
+# ----------- Process Multi-choice -------------
def parse_multi_choice_response(response, all_choices, index2ans):
    """
    Parse the predicted choice letter out of a generated response.

    Matching is attempted in decreasing order of confidence:
    1. bracketed letters, e.g. "(A)";
    2. bare letters surrounded by spaces, e.g. " A ";
    3. for long responses (> 5 tokens), the option *content* itself.
    If several candidates match, the one appearing last in the response
    wins; if none match, a random choice is returned (so the caller always
    gets a letter).

    Args:
        response: Raw model output.
        all_choices: Candidate letters, e.g. ["A", "B", "C", "D"].
        index2ans: Letter -> option text mapping.

    Returns:
        The predicted index, e.g. "A", "B", "C" or "D".
    """
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "  # add space to avoid partial match

    index_ans = True
    ans_with_brack = False
    candidates = []
    for choice in all_choices:  # e.g., (A) (B) (C) (D)
        if f"({choice})" in response:
            candidates.append(choice)
            ans_with_brack = True

    if len(candidates) == 0:
        for choice in all_choices:  # e.g., A B C D
            if f" {choice} " in response:
                candidates.append(choice)

    # if all above doesn't get candidates, check if the content
    # is larger than 5 tokens and try to parse the example
    if len(candidates) == 0 and len(response.split()) > 5:
        for index, ans in index2ans.items():
            if ans.lower() in response.lower():
                candidates.append(index)
                index_ans = False  # it's content ans.

    if len(candidates) == 0:  # still not get answer, randomly choose one.
        pred_index = random.choice(all_choices)
    elif len(candidates) > 1:
        # Disambiguate by position: prefer the candidate mentioned last.
        start_indexes = []
        if index_ans:
            if ans_with_brack:
                for can in candidates:
                    index = response.rfind(f"({can})")
                    start_indexes.append(index)  # -1 will be ignored anyway
                # start_indexes =
                # [generated_response.index(f'({can})') for can in candidates]
            else:
                for can in candidates:
                    index = response.rfind(f" {can} ")
                    start_indexes.append(index)
        else:
            for can in candidates:
                index = response.lower().rfind(index2ans[can].lower())
                start_indexes.append(index)
        # get the last one
        pred_index = candidates[np.argmax(start_indexes)]
    else:  # if only one candidate, use it.
        pred_index = candidates[0]

    return pred_index
+
+
+# ----------- Process Open -------------
def check_is_number(string):
    """
    Return True when *string* parses as a float after removing comma
    thousands-separators (e.g. "1,234.5"), False otherwise.
    """
    try:
        float(string.replace(",", ""))
    except ValueError:
        return False
    return True
+
+
def normalize_str(string):
    """
    Normalize an answer string for matching.

    Numeric strings (commas allowed as separators) become a single float
    rounded to two decimals; one-character strings are returned padded with
    a space on each side to avoid trivial substring matches; everything
    else is simply lowercased.

    Returns:
        A list of normalized variants of the input.
    """
    string = string.strip()

    # Inline numeric check (same test as check_is_number).
    try:
        numeric_value = float(string.replace(",", ""))
        is_number = True
    except ValueError:
        is_number = False

    if is_number:
        # leave 2 decimal places so minor precision differences still match
        return [round(numeric_value, 2)]

    lowered = string.lower()
    if len(lowered) == 1:
        return [" " + lowered, lowered + " "]  # avoid trivial matches
    return [lowered]
+
+
def extract_numbers(string):
    """
    Extract every number from *string* with regex, returning matches for
    comma-grouped numbers first, then scientific notation, then plain
    numbers (the order the patterns are applied in).
    """
    patterns = (
        r"-?\b\d{1,3}(?:,\d{3})+\b",  # numbers with comma separators
        r"-?\d+(?:\.\d+)?[eE][+-]?\d+",  # scientific notation
        r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])",  # simple numbers
    )
    found = []
    for pattern in patterns:
        found.extend(re.findall(pattern, string))
    return found
+
+
def parse_open_response(response):
    """
    Parse the prediction from a free-form generated response.

    Heuristically extracts the "key" sentence fragments (text following
    indicator phrases like "the answer is", "therefore", "="), then adds
    any numbers found in those fragments, normalizes everything via
    normalize_str, and deduplicates.

    Returns:
        A list of predicted strings and/or numbers.
    """

    # content = content.strip("\n").strip(".").strip(" ")
    def get_key_subresponses(response):
        # Split into sentences/lines and keep, per fragment, the shortest
        # tail following any indicator phrase (the tail most likely to be
        # just the answer itself).
        key_responses: list[str] = []
        response = response.strip().strip(".").lower()
        sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response)
        indicators_of_keys = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
        ]
        key_responses = []
        for index, resp in enumerate(sub_responses):
            # if last one, accept it's an equation
            # (the entire response can be just one sentence with equation)
            if index == len(sub_responses) - 1:
                indicators_of_keys.extend(["="])
            # the shortest response that may contain
            # the answer (tail part of the response)
            shortest_key_response = None
            for indicator in indicators_of_keys:
                if indicator in resp:
                    if not shortest_key_response:
                        shortest_key_response = resp.split(indicator)[-1].strip()
                    else:
                        if len(resp.split(indicator)[-1].strip()) < len(
                            shortest_key_response
                        ):
                            shortest_key_response = resp.split(indicator)[-1].strip()
                    # key_responses.append(resp.split(indicator)[1].strip())

            # Drop fragments that reduced to bare punctuation.
            if shortest_key_response and shortest_key_response.strip() not in [
                ":",
                ",",
                ".",
                "!",
                "?",
                ";",
                ":",
                "'",
            ]:
                key_responses.append(shortest_key_response)
        if len(key_responses) == 0:  # did not find any indicator: keep it all
            return [response]
        return key_responses

    # pdb.set_trace()
    key_responses = get_key_subresponses(response)

    pred_list = key_responses.copy()  # keep the original string response
    for resp in key_responses:
        pred_list.extend(extract_numbers(resp))

    tmp_pred_list = []
    for i in range(len(pred_list)):
        tmp_pred_list.extend(normalize_str(pred_list[i]))
    pred_list = tmp_pred_list

    # remove duplicates
    pred_list = list(set(pred_list))

    return pred_list
+
+
+# ----------- Evaluation -------------
+
+
def eval_multi_choice(gold_i, pred_i):
    """
    Evaluate one multiple-choice instance: only an exact match between the
    predicted letter and the gold answer (or any entry of a list-valued
    gold answer) counts as correct.
    """
    if isinstance(gold_i, list):
        return any(answer == pred_i for answer in gold_i)
    return gold_i == pred_i  # gold_i is a string
+
+
def eval_open(gold_i, pred_i):
    """
    Evaluate one open-question instance.

    Gold answers are normalized via normalize_str; a string prediction
    matches when any normalized *string* gold answer occurs inside it,
    and a numeric prediction matches by equality against the normalized
    gold values. pred_i entries are already normalized by the response
    parsing phase.
    """
    if isinstance(gold_i, list):
        norm_answers = [n for answer in gold_i for n in normalize_str(answer)]
    else:
        norm_answers = normalize_str(gold_i)

    for pred in pred_i:
        if isinstance(pred, str):
            # substring containment against string gold answers only
            if any(isinstance(n, str) and n in pred for n in norm_answers):
                return True
        elif pred in norm_answers:  # float prediction: exact equality
            return True
    return False
+
+
+# ----------- Batch Evaluation -------------
def evaluate(samples):
    """
    Batch evaluation for multiple choice and open questions.

    Args:
        samples: dicts carrying "id", "answer", "parsed_pred" and
            "question_type".

    Returns:
        Tuple (judge_dict, metrics): per-sample "Correct"/"Wrong" verdicts
        keyed by sample id, and {"acc": fraction correct}.

    Note: the empty-input path previously returned a bare ``{"acc": 0}``,
    which broke every caller that unpacks two values; it now returns the
    same two-tuple shape as the normal path.
    """
    if not samples:
        return {}, {"acc": 0}

    pred_correct = 0
    judge_dict = dict()
    for sample in samples:
        gold_i = sample["answer"]
        pred_i = sample["parsed_pred"]
        if sample["question_type"] == "multiple-choice":
            correct = eval_multi_choice(gold_i, pred_i)
        else:  # open question
            correct = eval_open(gold_i, pred_i)

        if correct:
            judge_dict[sample["id"]] = "Correct"
            pred_correct += 1
        else:
            judge_dict[sample["id"]] = "Wrong"

    return judge_dict, {"acc": pred_correct / len(samples)}
+
+
+# ----------- Calculate Accuracy -------------
def calculate_ins_level_acc(results: dict):
    """Compute the example-weighted (instance-level) accuracy across all
    subject results; returns 0 when there are no examples at all."""
    total_correct = 0
    total_examples = 0
    for cat_results in results.values():
        total_correct += cat_results["acc"] * cat_results["num_example"]
        total_examples += cat_results["num_example"]
    return total_correct / total_examples if total_examples else 0
+
+
+# ----------- Common Benchmark Logic -------------
def run_benchmark(
    samples: list[dict],
    config: dict,
    args: Any,
    generate_func: Callable[[list[str]], list[str]],
    batch_size: int = 1,
    subject: str | None = None,
    output_path: str = "benchmark_results.json",
    model_info: dict | None = None,
) -> dict:
    """
    Common benchmark logic for processing samples and evaluating results.

    Args:
        samples: List of dataset samples
        config: Evaluation configuration
        args: Arguments object containing generation parameters
        generate_func: Function that takes (prompts[, images]) and returns
            responses; image-unaware callables are detected via TypeError
            and called with prompts only.
        batch_size: Batch size for processing
        subject: Subject name for filtering results
        output_path: Path to save results
        model_info: Additional model information to save

    Returns:
        dictionary containing results, metrics, and other information
    """
    results = []

    # Set fixed seed for reproducibility
    if hasattr(args, "seed"):
        random.seed(args.seed)
        np.random.seed(args.seed)

    # Process samples in batches
    batch_count = (len(samples) + batch_size - 1) // batch_size
    for i in tqdm(
        range(0, len(samples), batch_size),
        desc="Processing batches",
        total=batch_count,
        unit="batch",
    ):
        batch_samples = samples[i : i + batch_size]

        # Prepare batch prompts and images
        batch_prompts = []
        batch_images = []
        for sample in batch_samples:
            prompt_data = construct_prompt(sample, config)
            prompt = prompt_data["final_input_prompt"]
            batch_prompts.append(prompt)
            batch_images.append(sample.get("image"))  # Get image data if available

            # Store prompt data for later use
            sample["_prompt_data"] = prompt_data
            sample["_prompt"] = prompt

        # Generate responses using the provided function
        # Check if generate_func accepts images parameter (for vLLM) or not (for HF)
        try:
            responses = generate_func(batch_prompts, batch_images)
        except TypeError:
            # Fallback for functions that only accept prompts
            responses = generate_func(batch_prompts)

        # Process outputs
        for j, response in enumerate(responses):
            sample = batch_samples[j]
            prompt_data = sample["_prompt_data"]

            # Parse response based on question type
            if sample["question_type"] == "multiple-choice":
                parsed_pred = parse_multi_choice_response(
                    response, prompt_data["all_choices"], prompt_data["index2ans"]
                )
            else:
                parsed_pred = parse_open_response(response)

            # Store results
            result = {
                "id": sample["id"],
                "question": sample["question"],
                "answer": sample["answer"],
                "question_type": sample["question_type"],
                "response": response,
                "parsed_pred": parsed_pred,
                "prompt": sample["_prompt"],
                "subject": sample.get("subject", "unknown"),
            }
            results.append(result)

        # Clean up memory periodically
        if i % (batch_size * 10) == 0:
            gc.collect()

    # Evaluate results
    judge_dict, metrics = evaluate(results)

    # Print results
    print("\nEvaluation Results:")
    print(f"Accuracy: {metrics['acc']:.4f}")

    # Group results by subject if multiple subjects
    if subject is None:
        subject_results: dict[str, list[dict]] = defaultdict(list)
        for result in results:
            subj = result.get("subject", "unknown")
            subject_results[subj].append(result)

        print("\nResults by Subject:")
        for subj, subject_samples in subject_results.items():
            subject_judge_dict, subject_metrics = evaluate(subject_samples)
            print(
                f"{subj}: {subject_metrics['acc']:.4f} ({len(subject_samples)} samples)"
            )

    # Prepare final results
    final_results = {
        "results": results,
        "metrics": metrics,
        "judge_dict": judge_dict,
        "args": {},
    }

    # Add model info and args
    if model_info:
        final_results["args"].update(model_info)

    if hasattr(args, "__dict__"):
        # Add relevant args
        for attr in [
            "seed",
            "max_samples",
            "temperature",
            "top_p",
            "max_tokens",
            "max_new_tokens",
        ]:
            if hasattr(args, attr):
                final_results["args"][attr] = getattr(args, attr)

    # Save results
    with open(output_path, "w") as f:
        json.dump(final_results, f, indent=2)

    print(f"Results saved to {output_path}")

    return final_results
+
+
def load_benchmark_dataset(
    split: str = "validation", subject: str | None = None, max_samples: int = -1
):
    """
    Load and prepare MMMU dataset for benchmarking.

    Args:
        split: Dataset split to use
        subject: Specific subject to evaluate
        max_samples: Maximum number of samples to process (-1 for all)

    Returns:
        List of processed samples
    """
    print("Loading MMMU dataset from HuggingFace Hub...")
    print(f"Split: {split}, Subject: {subject}")

    datasets_dict = load_mmmu_dataset(subset=split, subject=subject)

    # Flatten per-subject datasets into one list, tagging each sample with
    # its subject. The loop variable is named `subj` so it does not shadow
    # the `subject` parameter (the original rebinding was error-prone).
    samples = []
    for subj, dataset in datasets_dict.items():
        for raw_sample in dataset:
            processed = process_single_sample(raw_sample)
            processed["subject"] = subj
            samples.append(processed)

    # Limit number of samples if specified
    if max_samples > 0:
        samples = samples[:max_samples]

    print(f"Processing {len(samples)} samples...")
    return samples
+
+
def get_message(prompt, image):
    """Build a single-user-turn chat message list from a prompt and image.

    MMMU prompts embed the placeholder "<image 1>" where the image belongs;
    the prompt is split on that token and an image content block is
    inserted between the text parts. (The split token had been lost in
    transit — `prompt.split("")` raises ValueError: empty separator — so
    the upstream MMMU placeholder is restored here.)

    Args:
        prompt: Prompt text, possibly containing "<image 1>".
        image: Image object, or None for a text-only message.

    Returns:
        Message list in HuggingFace chat-template format.
    """
    if image is None:
        # Text-only: keep the prompt (placeholder included) as one block.
        return [{"role": "user", "content": [{"type": "text", "text": prompt}]}]

    split_prompt = prompt.split("<image 1>")
    content = [{"type": "text", "text": s} for s in split_prompt]
    content.insert(1, {"type": "image", "image": image})
    return [{"role": "user", "content": content}]
+
+
def load_benchmark_config(config_path: str = "eval_config.yaml"):
    """
    Load evaluation configuration.

    Values stored as single-element lists (the YAML layout used by
    eval_config.yaml) are unwrapped to their scalar; a missing file falls
    back to a built-in default config.

    Args:
        config_path: Path to configuration file

    Returns:
        Configuration dictionary

    Raises:
        ValueError: if a list-valued key has more than one entry (a real
            exception instead of `assert`, which vanishes under `python -O`).
    """
    if os.path.exists(config_path):
        config = load_yaml(config_path)
    else:
        # Default config. (The previous literal used a backslash line
        # continuation *inside* the string, embedding a long run of spaces
        # in the instruction text presented to the model.)
        config = {
            "multi_choice_example_format": "Question: {}\nOptions:\n{}\nAnswer:",
            "short_ans_example_format": "Question: {}\nAnswer:",
            "task_instructions": (
                "Please answer the following "
                "question based on the given information."
            ),
        }
    for key, value in config.items():
        if key != "eval_params" and isinstance(value, list):
            if len(value) != 1:
                raise ValueError(f"key {key} has more than one value")
            config[key] = value[0]
    return config
+
+
def add_common_benchmark_args(parser, framework: str = "hf"):
    """
    Add common benchmark arguments to a parser.

    Registers the dataset/benchmark group and the generation group shared
    by both runners, then framework-specific extras (output path for both;
    `--do-sample` only for HF, `--batch-size` only for vLLM).

    Args:
        parser: ArgumentParser instance
        framework: "hf", "vllm"

    Returns:
        The same parser, for call chaining.
    """
    defaults = BenchmarkDefaults()

    # Dataset arguments
    benchmark_group = parser.add_argument_group("Benchmark parameters")
    benchmark_group.add_argument(
        "--model", type=str, default=defaults.MODEL, help="model name"
    )
    benchmark_group.add_argument(
        "--split",
        type=str,
        default=defaults.SPLIT,
        choices=["validation", "test", "dev"],
        help="Dataset split to use",
    )
    benchmark_group.add_argument(
        "--subject",
        type=str,
        default=defaults.SUBJECT,
        help="Specific subject to evaluate (e.g., 'Art', 'Biology')."
        "If None, evaluates all subjects",
    )
    benchmark_group.add_argument(
        "--max-samples",
        type=int,
        default=defaults.MAX_SAMPLES,
        help="Maximum number of samples to process (-1 for all)",
    )
    benchmark_group.add_argument(
        "--config-path",
        type=str,
        default=defaults.CONFIG_PATH,
        help="Path to evaluation config file",
    )
    benchmark_group.add_argument(
        "--seed",
        type=int,
        default=defaults.SEED,
        help="Random seed for reproducibility",
    )

    # Generation arguments
    sampling_group = parser.add_argument_group("Generation parameters")
    sampling_group.add_argument(
        "--temperature",
        type=float,
        default=defaults.TEMPERATURE,
        help="Temperature for sampling (0.0 = deterministic)",
    )
    sampling_group.add_argument(
        "--max-tokens",
        type=int,
        default=defaults.MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    sampling_group.add_argument(
        "--top-p",
        type=float,
        default=defaults.TOP_P,
        help="Top-p (nucleus) sampling parameter",
    )
    sampling_group.add_argument(
        "--top-k", type=int, default=defaults.TOP_K, help="Top-k sampling parameter"
    )

    if framework == "hf":
        # HuggingFace specific args
        benchmark_group.add_argument(
            "--output-path",
            type=str,
            default=defaults.OUTPUT_PATH_HF,
            help="Path to save the results",
        )
        sampling_group.add_argument(
            "--do-sample",
            action="store_true",
            default=defaults.DO_SAMPLE,
            help="Whether to use sampling (vs greedy decoding)",
        )
    elif framework == "vllm":
        # vLLM specific args
        benchmark_group.add_argument(
            "--output-path",
            type=str,
            default=defaults.OUTPUT_PATH_VLLM,
            help="Path to save the results",
        )
        benchmark_group.add_argument(
            "--batch-size",
            type=int,
            default=defaults.BATCH_SIZE,
            help="Batch size for inference",
        )

    return parser