diff --git a/python/sglang/test/simple_eval_mgsm.py b/python/sglang/test/simple_eval_mgsm.py index 0b0b72a20f72..329d399c99d3 100644 --- a/python/sglang/test/simple_eval_mgsm.py +++ b/python/sglang/test/simple_eval_mgsm.py @@ -9,6 +9,8 @@ import re import urllib +import urllib.error +import urllib.request from typing import Optional from sglang.test import simple_eval_common as common @@ -112,16 +114,32 @@ def score_mgsm(target: str, prediction: str) -> bool: return target == prediction -def get_lang_examples(lang: str) -> list[dict[str, str]]: +def get_lang_examples( + lang: str, timeout: int = 60, max_retries: int = 3 +) -> list[dict[str, str]]: fpath = LANG_TO_FPATH[lang] examples = [] - with urllib.request.urlopen(fpath) as f: - for line in f.read().decode("utf-8").splitlines(): - inputs, targets = line.strip().split("\t") - if "." in targets: - raise ValueError(f"targets {targets} contains a decimal point.") - # targets = int(targets.replace(",", "")) - examples.append({"inputs": inputs, "targets": targets, "lang": lang}) + for attempt in range(max_retries): + try: + with urllib.request.urlopen(fpath, timeout=timeout) as f: + for line in f.read().decode("utf-8").splitlines(): + inputs, targets = line.strip().split("\t") + if "." in targets: + raise ValueError(f"targets {targets} contains a decimal point.") + examples.append( + {"inputs": inputs, "targets": targets, "lang": lang} + ) + return examples + except (urllib.error.URLError, TimeoutError) as e: + if attempt < max_retries - 1: + import time + + time.sleep(2**attempt) + else: + raise RuntimeError( + f"Failed to download MGSM data for {lang} from {fpath} " + f"after {max_retries} attempts: {e}" + ) from e return examples