diff --git a/python/sglang/test/kits/mmmu_vlm_kit.py b/python/sglang/test/kits/mmmu_vlm_kit.py
index b7415b0d4745..a1ca28fedc2d 100644
--- a/python/sglang/test/kits/mmmu_vlm_kit.py
+++ b/python/sglang/test/kits/mmmu_vlm_kit.py
@@ -1,8 +1,10 @@
 import glob
 import json
 import os
+import shutil
 import subprocess
 import tempfile
+from pathlib import Path
 from types import SimpleNamespace
 
 from sglang.srt.environ import temp_set_env
@@ -18,6 +20,126 @@
 DEFAULT_MEM_FRACTION_STATIC = 0.8
 
 
+def _is_mmmu_parquet_corruption(error_output: str) -> bool:
+    """Return True if error_output shows a corrupted MMMU parquet file.
+
+    All three markers must be present: the ``ArrowInvalid`` exception name,
+    the "Parquet magic bytes not found" message, and a mention of MMMU.
+    """
+    # "MMMU" is a substring of the cache directory name "lmms-lab--MMMU", so
+    # one membership test covers both spellings.
+    return (
+        "ArrowInvalid" in error_output
+        and "Parquet magic bytes not found" in error_output
+        and "MMMU" in error_output
+    )
+
+
+def _cleanup_mmmu_dataset_cache() -> bool:
+    """Remove the cached MMMU dataset directory so it is downloaded again.
+
+    Returns:
+        True if the cache directory existed and was removed; False if it was
+        not found or could not be removed.
+    """
+    # Priority 1: Check CI convention path /hf_home first (used in Docker containers)
+    ci_hf_home = Path("/hf_home/hub/datasets--lmms-lab--MMMU")
+    if ci_hf_home.exists():
+        mmmu_cache_path = ci_hf_home
+    else:
+        # Priority 2: Use HF_HOME env var or default user cache
+        # NOTE(review): HF_HUB_CACHE is not consulted; confirm CI never sets it.
+        hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
+        mmmu_cache_path = Path(hf_home) / "hub" / "datasets--lmms-lab--MMMU"
+
+    if not mmmu_cache_path.exists():
+        print(f"MMMU cache not found at {mmmu_cache_path}, skipping cleanup")
+        return False
+
+    print(f"Detected corrupted MMMU parquet cache. Cleaning up: {mmmu_cache_path}")
+    try:
+        shutil.rmtree(mmmu_cache_path)
+    except OSError as e:
+        print(f"Warning: Failed to remove cache {mmmu_cache_path}: {e}")
+        return False
+    print(f"Successfully removed corrupted cache: {mmmu_cache_path}")
+    return True
+
+
+def _captured_text(output) -> str:
+    """Return captured subprocess output as text.
+
+    ``None`` becomes an empty string and ``bytes`` are decoded, because
+    ``subprocess.TimeoutExpired`` may carry either even when ``text=True``.
+    """
+    if output is None:
+        return ""
+    if isinstance(output, bytes):
+        return output.decode(errors="replace")
+    return output
+
+
+def _run_lmms_eval_with_retry(cmd: list[str], timeout: int = 3600) -> None:
+    """Run lmms_eval, retrying once if the MMMU parquet cache is corrupted.
+
+    Output of the first attempt is captured so it can be inspected on failure,
+    and is printed afterwards. The retry is not captured.
+
+    Args:
+        cmd: Command line passed to ``subprocess.run``.
+        timeout: Timeout in seconds applied to each attempt.
+
+    Raises:
+        subprocess.CalledProcessError: If the command fails and is not
+            retried, or if the retry fails.
+        subprocess.TimeoutExpired: If an attempt exceeds ``timeout``.
+    """
+    try:
+        result = subprocess.run(
+            cmd,
+            check=True,
+            timeout=timeout,
+            capture_output=True,
+            text=True,
+        )
+    except subprocess.TimeoutExpired as e:
+        # Output is captured rather than streamed, so print whatever was
+        # collected before re-raising; otherwise a hung run leaves no log.
+        print(
+            f"lmms_eval timed out after {timeout}s.\n"
+            f"Stdout:\n{_captured_text(e.stdout)}\n"
+            f"Stderr:\n{_captured_text(e.stderr)}"
+        )
+        raise
+    except subprocess.CalledProcessError as e:
+        error_output = _captured_text(e.stderr) + _captured_text(e.stdout)
+        if _is_mmmu_parquet_corruption(error_output):
+            print("Detected MMMU parquet corruption error. Attempting recovery...")
+            if _cleanup_mmmu_dataset_cache():
+                print("Retrying lmms_eval with fresh download...")
+                with temp_set_env(
+                    HF_HUB_OFFLINE="0",
+                    HF_DATASETS_DOWNLOAD_MODE="force_redownload",
+                ):
+                    subprocess.run(cmd, check=True, timeout=timeout)
+            else:
+                print(
+                    f"Failed to cleanup corrupted MMMU cache. Error from lmms_eval:\nStdout:\n{e.stdout}\nStderr:\n{e.stderr}"
+                )
+                raise
+        else:
+            print(
+                f"lmms_eval failed with an unhandled error.\nStdout:\n{e.stdout}\nStderr:\n{e.stderr}"
+            )
+            raise
+    else:
+        # Print captured output to maintain visibility of successful runs
+        if result.stdout:
+            print(result.stdout, end="")
+        if result.stderr:
+            print(result.stderr, end="")
+
+
 class MMMUMixin:
     """Mixin for MMMU evaluation.
 
@@ -81,11 +203,7 @@ def run_mmmu_eval(
             OPENAI_API_KEY=self.api_key,
             OPENAI_API_BASE=f"{self.base_url}/v1",
         ):
-            subprocess.run(
-                cmd,
-                check=True,
-                timeout=3600,
-            )
+            _run_lmms_eval_with_retry(cmd)
 
     def test_mmmu(self: CustomTestCase):
         """Run MMMU evaluation test."""
@@ -209,11 +327,7 @@ def run_mmmu_eval(
             *self.mmmu_args,
         ]
 
-        subprocess.run(
-            cmd,
-            check=True,
-            timeout=3600,
-        )
+        _run_lmms_eval_with_retry(cmd)
 
     def _run_vlm_mmmu_test(
         self,