# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark HuggingFace vision-language models on the MMMU dataset."""
from typing import Optional

import torch
from eval_utils import (
    add_common_benchmark_args,
    get_message,
    load_benchmark_config,
    load_benchmark_dataset,
    run_benchmark,
)
from transformers import AutoModelForImageTextToText, AutoProcessor, set_seed

from vllm.utils import FlexibleArgumentParser


def load_model_and_processor(model_name: str):
    """Load a HuggingFace vision-language model and its processor.

    Args:
        model_name: HuggingFace Hub id or local path of the model.

    Returns:
        Tuple ``(model, processor)``; the model is in eval mode on CUDA.

    Raises:
        ValueError: if no auto class can load the model.  The underlying
            load error is chained as ``__cause__`` instead of being
            silently discarded.
    """
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    model = None
    last_error: Optional[Exception] = None
    # Single-element list kept deliberately so additional auto classes can
    # be appended later without restructuring the loop.
    for auto_class in [AutoModelForImageTextToText]:
        try:
            model = auto_class.from_pretrained(
                model_name, torch_dtype="auto", trust_remote_code=True
            )
            print(f"Successfully loaded model with {auto_class.__name__}")
            break
        except Exception as exc:
            # Remember why this auto class failed; the original swallowed
            # the exception, making load failures undiagnosable.
            last_error = exc
            continue

    if model is None:
        raise ValueError(
            f"Could not load model {model_name} with any available auto class"
        ) from last_error

    model = model.eval().cuda()

    return model, processor


def generate_response(
    model,
    processor,
    prompt: str,
    image,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: Optional[int],
    do_sample: bool,
    seed: int,
) -> str:
    """Generate a single response from a HuggingFace vision-language model.

    Args:
        model: Loaded HF model (on CUDA, eval mode).
        processor: Matching HF processor.
        prompt: User prompt text (may contain the MMMU image placeholder).
        image: PIL image or None for text-only prompts.
        max_tokens: Maximum number of new tokens to generate.
        temperature/top_p/top_k/do_sample: HF ``generate`` sampling knobs.
        seed: Seed re-applied before every call for reproducibility.

    Returns:
        The decoded, stripped model response (prompt tokens excluded).
    """
    # Re-seed on every call so each sample is individually reproducible.
    set_seed(seed)

    messages = get_message(prompt, image)

    # Apply the model's chat template to get the raw prompt string.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Tokenize text (and image, when present) into model inputs.
    inputs = processor(
        text=[text],
        images=[image] if image is not None else None,
        return_tensors="pt",
        padding=True,
    )
    inputs = inputs.to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )

    # Drop the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    response = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    return response.strip()


def hf_generate_func(model, processor, generation_params):
    """Create a generation function for HuggingFace VL models that matches
    the common benchmark interface (``prompts, images -> responses``).

    Args:
        model: Loaded HF model.
        processor: Matching HF processor.
        generation_params: Object exposing ``max_tokens``, ``temperature``,
            ``top_p``, ``top_k``, ``do_sample`` and ``seed`` attributes
            (the parsed CLI args are used directly).
    """

    def generate(prompts: list[str], images: Optional[list] = None) -> list[str]:
        """Generate responses one prompt at a time (HF path is unbatched)."""
        responses = []
        if images is None:
            images = [None] * len(prompts)

        for prompt, image in zip(prompts, images):
            response = generate_response(
                model,
                processor,
                prompt,
                image,
                max_tokens=generation_params.max_tokens,
                temperature=generation_params.temperature,
                top_p=generation_params.top_p,
                top_k=generation_params.top_k,
                do_sample=generation_params.do_sample,
                seed=generation_params.seed,
            )
            responses.append(response)
        return responses

    return generate


def main(args):
    """Run the MMMU benchmark with a HuggingFace model and return results."""
    print(f"Loading model from {args.model}...")
    model, processor = load_model_and_processor(args.model)

    # config_path is always provided by add_common_benchmark_args; the
    # hasattr guard only protects against hand-built args objects.
    config = load_benchmark_config(
        args.config_path if hasattr(args, "config_path") else "eval_config.yaml"
    )

    samples = load_benchmark_dataset(
        split=args.split, subject=args.subject, max_samples=args.max_samples
    )

    generate_func = hf_generate_func(model, processor, args)

    # Metadata persisted alongside the results file.
    model_info = {
        "model": args.model,
        "split": args.split,
        "subject": args.subject,
        "max_samples": args.max_samples,
    }

    results = run_benchmark(
        samples=samples,
        config=config,
        args=args,
        generate_func=generate_func,
        batch_size=1,  # HF processes one at a time
        subject=args.subject,
        output_path=args.output_path,
        model_info=model_info,
    )

    return results


def invoke_main() -> None:
    """CLI entry point: parse args and run the HF benchmark."""
    parser = FlexibleArgumentParser(
        description="Benchmark HuggingFace models on MMMU dataset from HuggingFace Hub"
    )

    # Add common benchmark arguments
    parser = add_common_benchmark_args(parser, framework="hf")

    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
def main(args: dict):
    """Run the MMMU benchmark with vLLM offline inference.

    Args:
        args: Flat dict of CLI args; benchmark/sampling keys are popped,
            the remainder is forwarded to the ``LLM`` constructor.

    Returns:
        The results dict produced by ``run_benchmark``.
    """
    # get common args (read before popping; ``seed`` stays in args because
    # it is also a valid vLLM engine argument)
    seed = args.get("seed")
    model_name = args.get("model")

    # Pop sampling arguments
    max_tokens = args.pop("max_tokens")
    temperature = args.pop("temperature")
    top_p = args.pop("top_p")
    top_k = args.pop("top_k")

    # Pop benchmark specific arguments
    split = args.pop("split")
    subject = args.pop("subject")
    max_samples = args.pop("max_samples")
    output_path = args.pop("output_path")
    config_path = args.pop("config_path")
    batch_size = args.pop("batch_size")

    # Create an LLM with remaining args
    print("Loading vLLM model...")
    args["disable_mm_preprocessor_cache"] = True
    llm = LLM(**args)

    # Load tokenizer for chat template
    print(f"Loading tokenizer from {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Create sampling params using the LLM instance; only override the
    # defaults for values the user actually supplied.
    sampling_params = llm.get_default_sampling_params()
    if max_tokens is not None:
        sampling_params.max_tokens = max_tokens
    if temperature is not None:
        sampling_params.temperature = temperature
    if top_p is not None:
        sampling_params.top_p = top_p
    if top_k is not None:
        sampling_params.top_k = top_k
    if seed is not None:
        sampling_params.seed = seed

    # Minimal args holder matching what run_benchmark reads via getattr.
    class Args:
        def __init__(self):
            self.seed = seed
            self.max_tokens = max_tokens
            self.temperature = temperature
            self.top_p = top_p
            self.top_k = top_k

    benchmark_args = Args()

    # Load evaluation config
    config = load_benchmark_config(config_path)

    # Load dataset
    samples = load_benchmark_dataset(
        split=split, subject=subject, max_samples=max_samples
    )

    # Model info for saving
    model_info = {
        "model": model_name,
        "split": split,
        "subject": subject,
        "max_samples": max_samples,
        "batch_size": batch_size,
    }

    # Create a generation function that matches the HF interface
    def generate_with_params(prompts: list[str], images=None) -> list[str]:
        """Generate responses for prompts with associated images.

        Args:
            prompts: List of prompt strings.
            images: Optional list of image data, one per prompt (entries
                may be None for text-only prompts).

        Returns:
            List of response strings, one per prompt.
        """
        # Prepare inputs for vLLM batch inference
        inputs = []
        if images is None:
            images = [None] * len(prompts)

        for prompt, image in zip(prompts, images):
            messages = get_message(prompt, image)
            try:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            except Exception as e:
                # NOTE: the original used a backslash line-continuation
                # inside the f-string, which baked a run of indentation
                # spaces into the warning text.
                print(
                    "Warning: Failed to apply chat template, "
                    f"using original prompt: {e}"
                )
                formatted_prompt = prompt

            input_data = {"prompt": formatted_prompt}
            if image is not None:
                input_data["multi_modal_data"] = {"image": image}
            inputs.append(input_data)

        # Use our pre-configured sampling_params
        outputs = llm.generate(inputs, sampling_params, use_tqdm=False)
        responses = []
        for output in outputs:
            responses.append(output.outputs[0].text.strip())
        return responses

    # Run benchmark
    results = run_benchmark(
        samples=samples,
        config=config,
        args=benchmark_args,
        generate_func=generate_with_params,
        batch_size=batch_size,
        subject=subject,
        output_path=output_path,
        model_info=model_info,
    )

    return results


def create_parser():
    """Build the CLI parser: vLLM engine args plus benchmark args."""
    parser = FlexibleArgumentParser(
        description="Benchmark vLLM models on MMMU dataset using offline inference",
        conflict_handler="resolve",
    )

    # Add engine args first (these provide base vLLM functionality)
    EngineArgs.add_cli_args(parser)

    # Add common benchmark arguments (these will override conflicting vLLM defaults)
    parser = add_common_benchmark_args(parser, framework="vllm")

    return parser


def invoke_main() -> None:
    """CLI entry point: parse args and run the vLLM benchmark."""
    parser = create_parser()
    args: dict = vars(parser.parse_args())
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/MMMU-Benchmark/MMMU

"""Utils for data load, save, and process (e.g., prompt construction)"""

import ast
import json
import os
import re
from typing import Optional

# MMMU top-level domains mapped to their subject subsets.
DOMAIN_CAT2SUB_CAT = {
    "Art and Design": ["Art", "Art_Theory", "Design", "Music"],
    "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
    "Science": [
        "Biology",
        "Chemistry",
        "Geography",
        "Math",
        "Physics",
    ],
    "Health and Medicine": [
        "Basic_Medical_Science",
        "Clinical_Medicine",
        "Diagnostics_and_Laboratory_Medicine",
        "Pharmacy",
        "Public_Health",
    ],
    "Humanities and Social Science": [
        "History",
        "Literature",
        "Sociology",
        "Psychology",
    ],
    "Tech and Engineering": [
        "Agriculture",
        "Architecture_and_Engineering",
        "Computer_Science",
        "Electronics",
        "Energy_and_Power",
        "Materials",
        "Mechanical_Engineering",
    ],
}


# Short subject aliases used on the command line mapped to dataset names.
CAT_SHORT2LONG = {
    "acc": "Accounting",
    "agri": "Agriculture",
    "arch": "Architecture_and_Engineering",
    "art": "Art",
    "art_theory": "Art_Theory",
    "bas_med": "Basic_Medical_Science",
    "bio": "Biology",
    "chem": "Chemistry",
    "cli_med": "Clinical_Medicine",
    "cs": "Computer_Science",
    "design": "Design",
    "diag_med": "Diagnostics_and_Laboratory_Medicine",
    "econ": "Economics",
    "elec": "Electronics",
    "ep": "Energy_and_Power",
    "fin": "Finance",
    "geo": "Geography",
    "his": "History",
    "liter": "Literature",
    "manage": "Manage",
    "mark": "Marketing",
    "mate": "Materials",
    "math": "Math",
    "mech": "Mechanical_Engineering",
    "music": "Music",
    "phar": "Pharmacy",
    "phys": "Physics",
    "psy": "Psychology",
    "pub_health": "Public_Health",
    "socio": "Sociology",
}


def load_mmmu_dataset(subset: str = "validation", subject: Optional[str] = None):
    """Load MMMU dataset from HuggingFace Hub.

    Args:
        subset: Dataset split ("validation", "test", "dev").
        subject: A single subject name, or None to load every subject.

    Returns:
        Dict mapping subject name -> HuggingFace dataset split.
    """
    # Deferred import: ``datasets`` is a heavy optional dependency only
    # needed when actually downloading the benchmark data.
    from datasets import load_dataset

    available_subjects = list(CAT_SHORT2LONG.values()) if subject is None else [subject]
    datasets_dict = {}
    for subj in set(available_subjects):
        subj_dataset = load_dataset("MMMU/MMMU", subj, split=subset)
        datasets_dict[subj] = subj_dataset
    return datasets_dict


def get_multi_choice_info(options):
    """
    Given the list of options for multiple choice question
    Return the index2ans and all_choices
    """

    start_chr = "A"
    all_choices = []
    index2ans = {}
    for i, option in enumerate(options):
        index2ans[chr(ord(start_chr) + i)] = option
        all_choices.append(chr(ord(start_chr) + i))

    return index2ans, all_choices


def load_yaml(file_path):
    """Load a YAML file, chaining a clear error on parse failure.

    The original printed the YAMLError and then returned an unbound
    variable (NameError); now the parse error is re-raised with context.
    """
    # Deferred import: PyYAML is only needed when a config file is read.
    import yaml

    with open(file_path) as stream:
        try:
            yaml_dict = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise ValueError(f"Failed to parse YAML file: {file_path}") from exc

    return yaml_dict


def parse_img_path(text):
    """Extract image paths from ``<img='...'>`` tags embedded in *text*.

    The original pattern was an empty string (the tag was stripped by
    markup-eating), which matched the empty string at every position and
    made process_single_sample discard every image.
    """
    matches = re.findall(r"<img='(.*?)'>", text)
    return matches


def process_single_sample(data):
    """Normalize one raw MMMU record into the benchmark sample format.

    Samples whose *options* embed more than one image are kept but with
    ``image`` set to None (text-only), matching the upstream behavior.
    """
    question = data["question"]
    o_imgs_paths = []
    for option in data["options"]:
        current_o_imgs_paths = parse_img_path(option)
        for img_path in current_o_imgs_paths:
            o_imgs_paths.append(img_path)

    if len(o_imgs_paths) > 1:  # multiple images in options, used for random selection
        return {
            "id": data["id"],
            "question": question,
            "options": data["options"],
            "answer": data["answer"],
            "image": None,
            "question_type": data["question_type"],
        }
    else:
        return {
            "id": data["id"],
            "question": question,
            "options": data["options"],
            "answer": data["answer"],
            "image": data["image_1"],
            "question_type": data["question_type"],
        }


# DATA SAVING
def save_json(filename, ds):
    """Write *ds* to *filename* as pretty-printed JSON."""
    with open(filename, "w") as f:
        json.dump(ds, f, indent=4)


def save_jsonl(filename, data):
    """
    Save a dictionary of data to a JSON Lines file with the filename as
    key and caption as value.

    Args:
        filename (str): The path to the file where the data should be saved.
        data (dict): The dictionary containing the data to save where key
            is the image path and value is the caption.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for img_path, caption in data.items():
            # Extract the base filename without the extension
            base_filename = os.path.basename(img_path)
            # Create a JSON object with the filename as the key and caption as the value
            json_record = json.dumps({base_filename: caption}, ensure_ascii=False)
            # Write the JSON object to the file, one per line
            f.write(json_record + "\n")


def save_args(args, path_dir):
    """Dump all parsed CLI args to ``<path_dir>setting.txt`` for the record."""
    argsDict = args.__dict__
    with open(path_dir + "setting.txt", "w") as f:
        f.writelines("------------------ start ------------------" + "\n")
        for eachArg, value in argsDict.items():
            f.writelines(eachArg + " : " + str(value) + "\n")
        f.writelines("------------------- end -------------------")


# DATA PROCESSING
def construct_prompt(sample, config):
    """Build the final input prompt for one MMMU sample.

    Args:
        sample: Processed sample dict (``options`` is a Python-literal
            string, e.g. ``"['x', 'y']"``).
        config: Eval config with ``multi_choice_example_format``,
            ``short_ans_example_format`` and ``task_instructions``.

    Returns:
        Dict with ``final_input_prompt``, ``gt_content`` and (for
        multiple-choice) ``index2ans``/``all_choices``/``correct_choice``,
        merged with the original sample fields.
    """
    question = sample["question"]
    # options is stored as a Python-literal string in the dataset.
    options = ast.literal_eval(sample["options"])
    example = ""
    if sample["question_type"] == "multiple-choice":
        start_chr = "A"
        prediction_range = []
        index2ans = {}
        for option in options:
            prediction_range.append(start_chr)
            example += f"({start_chr}) {option}\n"
            index2ans[start_chr] = option
            start_chr = chr(ord(start_chr) + 1)
        empty_prompt_sample_structure = config["multi_choice_example_format"]
        empty_prompt = empty_prompt_sample_structure.format(question, example)
        res_dict = {}
        res_dict["index2ans"] = index2ans
        res_dict["correct_choice"] = sample["answer"]
        res_dict["all_choices"] = prediction_range
        res_dict["empty_prompt"] = empty_prompt
        if config["task_instructions"]:
            res_dict["final_input_prompt"] = (
                config["task_instructions"].strip() + "\n\n" + empty_prompt
            )
        else:
            res_dict["final_input_prompt"] = empty_prompt

        # Ground-truth option text, looked up by the answer letter.
        res_dict["gt_content"] = options[ord(sample["answer"].upper()) - ord("A")]
    else:
        empty_prompt_sample_structure = config["short_ans_example_format"]
        empty_prompt = empty_prompt_sample_structure.format(question)
        res_dict = {}
        res_dict["empty_prompt"] = empty_prompt
        if config["task_instructions"]:
            res_dict["final_input_prompt"] = (
                config["task_instructions"].strip() + "\n\n" + empty_prompt
            )
        else:
            res_dict["final_input_prompt"] = empty_prompt
        res_dict["gt_content"] = sample["answer"]

    res_dict.update(sample)
    return res_dict
# ----------- Default Configuration -------------
class BenchmarkDefaults:
    """Default values for benchmark parameters shared by both frameworks."""

    # Dataset parameters
    SPLIT = "validation"
    SUBJECT = None
    MAX_SAMPLES = -1
    CONFIG_PATH = "eval_config.yaml"

    # Generation parameters
    SEED = 42
    TEMPERATURE = 0.01
    TOP_P = 0.9
    TOP_K = None
    MAX_TOKENS = 512
    DO_SAMPLE = True

    # Benchmark parameters
    BATCH_SIZE = 1
    OUTPUT_PATH_HF = "benchmark_results_hf.json"
    OUTPUT_PATH_VLLM = "benchmark_results_vllm.json"

    # vLLM specific defaults
    MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"

    @classmethod
    def get_common_args_dict(cls):
        """Return the defaults shared by every framework as a dict."""
        return {
            "split": cls.SPLIT,
            "subject": cls.SUBJECT,
            "max_samples": cls.MAX_SAMPLES,
            "config_path": cls.CONFIG_PATH,
            "seed": cls.SEED,
            "temperature": cls.TEMPERATURE,
            "top_p": cls.TOP_P,
            "top_k": cls.TOP_K,
            "max_tokens": cls.MAX_TOKENS,
            "do_sample": cls.DO_SAMPLE,
            "batch_size": cls.BATCH_SIZE,
        }

    @classmethod
    def get_hf_args_dict(cls):
        """Common defaults plus the HuggingFace-specific output path."""
        args = cls.get_common_args_dict()
        args["output_path"] = cls.OUTPUT_PATH_HF
        return args

    @classmethod
    def get_vllm_args_dict(cls):
        """Common defaults plus the vLLM-specific model and output path.

        The redundant ``top_p`` reassignment from the original was
        dropped: get_common_args_dict() already sets it to TOP_P.
        """
        args = cls.get_common_args_dict()
        args["model"] = cls.MODEL
        args["output_path"] = cls.OUTPUT_PATH_VLLM
        return args
# ----------- Process Multi-choice -------------
def parse_multi_choice_response(response, all_choices, index2ans):
    """Parse the predicted choice letter (e.g. "A") from a free-form response.

    Matching strategy, in order: bracketed letters "(A)", bare letters
    " A ", then full answer text from *index2ans* (only for responses
    longer than five words).  When several candidates match, the one
    occurring last in the response wins; when none match, a random choice
    is returned (seed random in the caller for reproducibility).
    """
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "  # add space to avoid partial match

    index_ans = True
    ans_with_brack = False
    candidates = []
    for choice in all_choices:  # e.g., (A) (B) (C) (D)
        if f"({choice})" in response:
            candidates.append(choice)
            ans_with_brack = True

    if len(candidates) == 0:
        for choice in all_choices:  # e.g., A B C D
            if f" {choice} " in response:
                candidates.append(choice)

    # if all above doesn't get candidates, check if the content
    # is larger than 5 tokens and try to parse the example
    if len(candidates) == 0 and len(response.split()) > 5:
        for index, ans in index2ans.items():
            if ans.lower() in response.lower():
                candidates.append(index)
                index_ans = False  # it's content ans.

    if len(candidates) == 0:  # still not get answer, randomly choose one.
        pred_index = random.choice(all_choices)
    elif len(candidates) > 1:
        # Multiple matches: keep the candidate whose last occurrence is
        # furthest right (models usually state the final answer last).
        start_indexes = []
        if index_ans:
            if ans_with_brack:
                for can in candidates:
                    index = response.rfind(f"({can})")
                    start_indexes.append(index)  # -1 will be ignored anyway
            else:
                for can in candidates:
                    index = response.rfind(f" {can} ")
                    start_indexes.append(index)
        else:
            for can in candidates:
                index = response.lower().rfind(index2ans[can].lower())
                start_indexes.append(index)
        # get the last one
        pred_index = candidates[np.argmax(start_indexes)]
    else:  # if only one candidate, use it.
        pred_index = candidates[0]

    return pred_index


# ----------- Process Open -------------
def check_is_number(string):
    """Return True if *string* (commas removed) parses as a float."""
    try:
        float(string.replace(",", ""))
        return True
    except ValueError:
        # check if there's comma inside
        return False


def normalize_str(string):
    """Normalize an answer string for fuzzy matching.

    Numeric strings become floats rounded to two decimals; other strings
    are lower-cased.  Single characters are returned with leading- and
    trailing-space variants so a bare letter only matches whole words.

    Returns:
        List of normalized variants of *string*.
    """
    string = string.strip()

    is_number = check_is_number(string)

    if is_number:
        string = string.replace(",", "")
        string = float(string)
        # leave 2 decimal
        string = round(string, 2)
        return [string]
    else:  # it's likely to be a string
        # lower it
        string = string.lower()
        if len(string) == 1:
            return [" " + string, string + " "]  # avoid trivial matches
        return [string]


def extract_numbers(string):
    """
    Exact all forms of numbers from a string with regex.
    """
    # Pattern for numbers with commas
    pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
    # Pattern for scientific notation
    pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
    # Pattern for simple numbers without commas
    pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"

    # Extract numbers with commas
    numbers_with_commas = re.findall(pattern_commas, string)
    # Extract numbers in scientific notation
    numbers_scientific = re.findall(pattern_scientific, string)
    # Extract simple numbers without commas
    numbers_simple = re.findall(pattern_simple, string)

    # Combine all extracted numbers
    all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
    return all_numbers


def parse_open_response(response):
    """Parse an open-ended response into a list of candidate answers.

    Splits the response into sub-sentences, keeps the text following
    answer indicators ("is ", "answer ", ...), augments each candidate
    with every number it contains, then normalizes and deduplicates.
    """

    def get_key_subresponses(response):
        response = response.strip().strip(".").lower()
        sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response)
        indicators_of_keys = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
        ]
        key_responses = []
        for index, resp in enumerate(sub_responses):
            # if last one, accept it's an equation
            # (the entire response can be just one sentence with equation)
            if index == len(sub_responses) - 1:
                indicators_of_keys.extend(["="])
            # the shortest response that may contain
            # the answer (tail part of the response)
            shortest_key_response = None
            for indicator in indicators_of_keys:
                if indicator in resp:
                    if not shortest_key_response:
                        shortest_key_response = resp.split(indicator)[-1].strip()
                    else:
                        if len(resp.split(indicator)[-1].strip()) < len(
                            shortest_key_response
                        ):
                            shortest_key_response = resp.split(indicator)[-1].strip()

            if shortest_key_response and shortest_key_response.strip() not in [
                ":",
                ",",
                ".",
                "!",
                "?",
                ";",
                ":",
                "'",
            ]:
                key_responses.append(shortest_key_response)
        if len(key_responses) == 0:  # did not found any
            return [response]
        return key_responses

    key_responses = get_key_subresponses(response)

    pred_list = key_responses.copy()  # keep the original string response
    for resp in key_responses:
        pred_list.extend(extract_numbers(resp))

    tmp_pred_list = []
    for i in range(len(pred_list)):
        tmp_pred_list.extend(normalize_str(pred_list[i]))
    pred_list = tmp_pred_list

    # remove duplicates
    pred_list = list(set(pred_list))

    return pred_list
+ """ + correct = False + # only they are exactly the same, we consider it as correct + if isinstance(gold_i, list): + for answer in gold_i: + if answer == pred_i: + correct = True + break + else: # gold_i is a string + if gold_i == pred_i: + correct = True + return correct + + +def eval_open(gold_i, pred_i): + """ + Evaluate an open question instance + """ + correct = False + if isinstance(gold_i, list): + # use float to avoid trivial matches + norm_answers = [] + for answer in gold_i: + norm_answers.extend(normalize_str(answer)) + else: + norm_answers = normalize_str(gold_i) + for pred in pred_i: # pred is already normalized in parse response phase + if isinstance(pred, str): # if it's a string, then find if ans in the pred_i + for norm_ans in norm_answers: + # only see if the string answer in the string pred + if isinstance(norm_ans, str) and norm_ans in pred: + if not correct: + correct = True + break + else: # it's a float number + if pred in norm_answers: + if not correct: + correct = True + break + return correct + + +# ----------- Batch Evaluation ------------- +def evaluate(samples): + """ + Batch evaluation for multiple choice and open questions. 
+ """ + pred_correct = 0 + judge_dict = dict() + for sample in samples: + gold_i = sample["answer"] + pred_i = sample["parsed_pred"] + if sample["question_type"] == "multiple-choice": + correct = eval_multi_choice(gold_i, pred_i) + else: # open question + correct = eval_open(gold_i, pred_i) + + if correct: + judge_dict[sample["id"]] = "Correct" + pred_correct += 1 + else: + judge_dict[sample["id"]] = "Wrong" + + if len(samples) == 0: + return {"acc": 0} + return judge_dict, {"acc": pred_correct / len(samples)} + + +# ----------- Calculate Accuracy ------------- +def calculate_ins_level_acc(results: dict): + """Calculate the instruction level accuracy for given Subject results""" + acc = 0 + ins_num = 0 + for cat_results in results.values(): + acc += cat_results["acc"] * cat_results["num_example"] + ins_num += cat_results["num_example"] + if ins_num == 0: + return 0 + return acc / ins_num + + +# ----------- Common Benchmark Logic ------------- +def run_benchmark( + samples: list[dict], + config: dict, + args: Any, + generate_func: Callable[[list[str]], list[str]], + batch_size: int = 1, + subject: str | None = None, + output_path: str = "benchmark_results.json", + model_info: dict | None = None, +) -> dict: + """ + Common benchmark logic for processing samples and evaluating results. 
# ----------- Common Benchmark Logic -------------
def run_benchmark(
    samples: list[dict],
    config: dict,
    args: Any,
    generate_func: Callable[[list[str]], list[str]],
    batch_size: int = 1,
    subject: str | None = None,
    output_path: str = "benchmark_results.json",
    model_info: dict | None = None,
) -> dict:
    """
    Common benchmark logic for processing samples and evaluating results.

    Args:
        samples: List of dataset samples
        config: Evaluation configuration
        args: Arguments object containing generation parameters
        generate_func: Function that takes (prompts[, images]) and
            returns responses
        batch_size: Batch size for processing
        subject: Subject name for filtering results
        output_path: Path to save results
        model_info: Additional model information to save

    Returns:
        dictionary containing results, metrics, and other information
    """
    import inspect

    results = []

    # Set fixed seed for reproducibility
    if hasattr(args, "seed"):
        random.seed(args.seed)
        np.random.seed(args.seed)

    # Decide ONCE whether generate_func accepts an images argument.  The
    # original called it with two args and retried with one on TypeError,
    # which would re-invoke the model (duplicating generation side
    # effects) whenever a TypeError escaped from *inside* the function.
    try:
        accepts_images = len(inspect.signature(generate_func).parameters) >= 2
    except (TypeError, ValueError):
        # Builtins / C callables without introspectable signatures:
        # assume the richer two-argument interface.
        accepts_images = True

    # Process samples in batches
    batch_count = (len(samples) + batch_size - 1) // batch_size
    for i in tqdm(
        range(0, len(samples), batch_size),
        desc="Processing batches",
        total=batch_count,
        unit="batch",
    ):
        batch_samples = samples[i : i + batch_size]

        # Prepare batch prompts and images
        batch_prompts = []
        batch_images = []
        for sample in batch_samples:
            prompt_data = construct_prompt(sample, config)
            prompt = prompt_data["final_input_prompt"]
            batch_prompts.append(prompt)
            batch_images.append(sample.get("image"))  # Get image data if available

            # Store prompt data for later use
            sample["_prompt_data"] = prompt_data
            sample["_prompt"] = prompt

        # Generate responses using the provided function
        if accepts_images:
            responses = generate_func(batch_prompts, batch_images)
        else:
            responses = generate_func(batch_prompts)

        # Process outputs
        for j, response in enumerate(responses):
            sample = batch_samples[j]
            prompt_data = sample["_prompt_data"]

            # Parse response based on question type
            if sample["question_type"] == "multiple-choice":
                parsed_pred = parse_multi_choice_response(
                    response, prompt_data["all_choices"], prompt_data["index2ans"]
                )
            else:
                parsed_pred = parse_open_response(response)

            # Store results
            result = {
                "id": sample["id"],
                "question": sample["question"],
                "answer": sample["answer"],
                "question_type": sample["question_type"],
                "response": response,
                "parsed_pred": parsed_pred,
                "prompt": sample["_prompt"],
                "subject": sample.get("subject", "unknown"),
            }
            results.append(result)

        # Clean up memory periodically
        if i % (batch_size * 10) == 0:
            gc.collect()

    # Evaluate results
    judge_dict, metrics = evaluate(results)

    # Print results
    print("\nEvaluation Results:")
    print(f"Accuracy: {metrics['acc']:.4f}")

    # Group results by subject if multiple subjects
    if subject is None:
        subject_results: dict[str, list[dict]] = defaultdict(list)
        for result in results:
            subj = result.get("subject", "unknown")
            subject_results[subj].append(result)

        print("\nResults by Subject:")
        for subj, subject_samples in subject_results.items():
            subject_judge_dict, subject_metrics = evaluate(subject_samples)
            print(
                f"{subj}: {subject_metrics['acc']:.4f} ({len(subject_samples)} samples)"
            )

    # Prepare final results
    final_results = {
        "results": results,
        "metrics": metrics,
        "judge_dict": judge_dict,
        "args": {},
    }

    # Add model info and args
    if model_info:
        final_results["args"].update(model_info)

    if hasattr(args, "__dict__"):
        # Add relevant args
        for attr in [
            "seed",
            "max_samples",
            "temperature",
            "top_p",
            "max_tokens",
            "max_new_tokens",
        ]:
            if hasattr(args, attr):
                final_results["args"][attr] = getattr(args, attr)

    # Save results
    with open(output_path, "w") as f:
        json.dump(final_results, f, indent=2)

    print(f"Results saved to {output_path}")

    return final_results
+ + Args: + split: Dataset split to use + subject: Specific subject to evaluate + max_samples: Maximum number of samples to process (-1 for all) + + Returns: + List of processed samples + """ + print("Loading MMMU dataset from HuggingFace Hub...") + print(f"Split: {split}, Subject: {subject}") + + datasets_dict = load_mmmu_dataset(subset=split, subject=subject) + + # Convert dataset samples to our format + samples = [] + for subject, dataset in datasets_dict.items(): + for sample in dataset: + sample = process_single_sample(sample) + sample["subject"] = subject + samples.append(sample) + + # Limit number of samples if specified + if max_samples > 0: + samples = samples[:max_samples] + + print(f"Processing {len(samples)} samples...") + return samples + + +def get_message(prompt, image): + split_prompt = prompt.split("") + content = [{"type": "text", "text": s} for s in split_prompt] + content.insert(1, {"type": "image", "image": image} if image is not None else None) + messages = [{"role": "user", "content": content}] + if image is None: + messages[0]["content"] = [{"type": "text", "text": prompt}] + return messages + + +def load_benchmark_config(config_path: str = "eval_config.yaml"): + """ + Load evaluation configuration. + + Args: + config_path: Path to configuration file + + Returns: + Configuration dictionary + """ + if os.path.exists(config_path): + config = load_yaml(config_path) + else: + # Default config + config = { + "multi_choice_example_format": "Question: {}\nOptions:\n{}\nAnswer:", + "short_ans_example_format": "Question: {}\nAnswer:", + "task_instructions": "Please answer the following\ + question based on the given information.", + } + for key, value in config.items(): + if key != "eval_params" and isinstance(value, list): + assert len(value) == 1, "key {} has more than one value".format(key) + config[key] = value[0] + return config + + +def add_common_benchmark_args(parser, framework: str = "hf"): + """ + Add common benchmark arguments to a parser. 
+ + Args: + parser: ArgumentParser instance + framework: "hf", "vllm" + """ + defaults = BenchmarkDefaults() + + # Dataset arguments + benchmark_group = parser.add_argument_group("Benchmark parameters") + benchmark_group.add_argument( + "--model", type=str, default=defaults.MODEL, help="model name" + ) + benchmark_group.add_argument( + "--split", + type=str, + default=defaults.SPLIT, + choices=["validation", "test", "dev"], + help="Dataset split to use", + ) + benchmark_group.add_argument( + "--subject", + type=str, + default=defaults.SUBJECT, + help="Specific subject to evaluate (e.g., 'Art', 'Biology')." + "If None, evaluates all subjects", + ) + benchmark_group.add_argument( + "--max-samples", + type=int, + default=defaults.MAX_SAMPLES, + help="Maximum number of samples to process (-1 for all)", + ) + benchmark_group.add_argument( + "--config-path", + type=str, + default=defaults.CONFIG_PATH, + help="Path to evaluation config file", + ) + benchmark_group.add_argument( + "--seed", + type=int, + default=defaults.SEED, + help="Random seed for reproducibility", + ) + + # Generation arguments + sampling_group = parser.add_argument_group("Generation parameters") + sampling_group.add_argument( + "--temperature", + type=float, + default=defaults.TEMPERATURE, + help="Temperature for sampling (0.0 = deterministic)", + ) + sampling_group.add_argument( + "--max-tokens", + type=int, + default=defaults.MAX_TOKENS, + help="Maximum number of tokens to generate", + ) + sampling_group.add_argument( + "--top-p", + type=float, + default=defaults.TOP_P, + help="Top-p (nucleus) sampling parameter", + ) + sampling_group.add_argument( + "--top-k", type=int, default=defaults.TOP_K, help="Top-k sampling parameter" + ) + + if framework == "hf": + # HuggingFace specific args + benchmark_group.add_argument( + "--output-path", + type=str, + default=defaults.OUTPUT_PATH_HF, + help="Path to save the results", + ) + sampling_group.add_argument( + "--do-sample", + action="store_true", + 
default=defaults.DO_SAMPLE, + help="Whether to use sampling (vs greedy decoding)", + ) + elif framework == "vllm": + # vLLM specific args + benchmark_group.add_argument( + "--output-path", + type=str, + default=defaults.OUTPUT_PATH_VLLM, + help="Path to save the results", + ) + benchmark_group.add_argument( + "--batch-size", + type=int, + default=defaults.BATCH_SIZE, + help="Batch size for inference", + ) + + return parser