diff --git a/src/srtctl/core/backend.py b/src/srtctl/core/backend.py index 3419ed41..d7d5c102 100644 --- a/src/srtctl/core/backend.py +++ b/src/srtctl/core/backend.py @@ -142,6 +142,25 @@ def generate_slurm_script(self, config_path: Path = None, timestamp: str = None) conc = benchmark_config.get("concurrencies") conc_str = "x".join(str(c) for c in conc) if isinstance(conc, list) else str(conc) parsable_config = f"{benchmark_config.get('isl')} {benchmark_config.get('osl')} {conc_str} {benchmark_config.get('req_rate', 'inf')}" + elif bench_type == "mmlu": + num_examples = benchmark_config.get("num_examples", 200) + max_tokens = benchmark_config.get("max_tokens", 2048) + repeat = benchmark_config.get("repeat", 8) + num_threads = benchmark_config.get("num_threads", 512) + parsable_config = f"{num_examples} {max_tokens} {repeat} {num_threads}" + elif bench_type == "gpqa": + num_examples = benchmark_config.get("num_examples", 198) + max_tokens = benchmark_config.get("max_tokens", 32768) + repeat = benchmark_config.get("repeat", 8) + num_threads = benchmark_config.get("num_threads", 128) + parsable_config = f"{num_examples} {max_tokens} {repeat} {num_threads}" + elif bench_type == "longbenchv2": + num_examples = benchmark_config.get("num_examples", None) + max_tokens = benchmark_config.get("max_tokens", 16384) + max_context_length = benchmark_config.get("max_context_length", 128000) + num_threads = benchmark_config.get("num_threads", 16) + categories = benchmark_config.get("categories", None) + parsable_config = f"{num_examples} {max_tokens} {max_context_length} {num_threads} {categories}" # Paths srtctl_root = Path(get_srtslurm_setting("srtctl_root") or Path(srtctl.__file__).parent.parent.parent)