diff --git a/docs/accuracy.md b/docs/accuracy.md index f5588c9f..98b69b46 100644 --- a/docs/accuracy.md +++ b/docs/accuracy.md @@ -1,6 +1,6 @@ # Accuracy Benchmarks -In srt-slurm, users can run different accuracy benchmarks by setting the benchmark section in the config yaml file. Supported benchmarks include `mmlu`, `gpqa` and `longbenchv2`. +In srt-slurm, users can run different accuracy benchmarks by setting the benchmark section in the config yaml file. Supported benchmarks include `mmlu`, `gpqa`, `longbenchv2`, and `lm-eval`. ## Table of Contents @@ -14,6 +14,7 @@ In srt-slurm, users can run different accuracy benchmarks by setting the benchma - [Example: Quick Validation](#example-quick-validation) - [Output](#output) - [Important Notes](#important-notes) +- [lm-eval (InferenceX)](#lm-eval-inferencex) --- @@ -191,3 +192,84 @@ The output includes per-category scores and aggregate metrics: 4. **Categories**: Running specific categories is useful for targeted validation (e.g., just testing summarization capabilities) +## lm-eval (InferenceX) + +The `lm-eval` benchmark runner integrates [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) via InferenceX's `benchmark_lib.sh`. Unlike the built-in benchmarks above, this runner sources evaluation logic from an external InferenceX workspace mounted at `/infmax-workspace`. + +This is used by InferenceX CI to run evals such as GSM8K and GPQA against NVIDIA multi-node disaggregated deployments on GB200, GB300, B200, B300, H100, and H200. AMD MI355X multi-node evals are handled by InferenceX's upstreamed AMD Slurm path, not by this srt-slurm runner. + +In InferenceX CI, recipes normally keep their throughput benchmark configuration. `do_sweep.py` invokes the registered `lm-eval` runner as a post-step when `RUN_EVAL=true`, or as the only benchmark-like step when `EVAL_ONLY=true`. There is no separate `infmax-eval` benchmark type. + +### How it works + +1. `RuntimeContext` mounts the host path from `INFMAX_WORKSPACE` at `/infmax-workspace` inside the Slurm container. +2. `do_sweep.py` starts infrastructure, workers, and the frontend for the normal recipe topology. +3. For `EVAL_ONLY=true`, `do_sweep.py` skips the throughput benchmark stage and runs `_run_post_eval()` directly after frontend startup. +4. `_run_post_eval()` waits for the OpenAI-compatible endpoint on port 8000 and, in eval-only mode, performs the full `wait_for_model()` health check for the configured prefill/decode or aggregated topology. +5. `_run_post_eval()` launches the registered `lm-eval` runner on the head node and passes through InferenceX metadata such as framework, precision, sequence length, prefill/decode topology, and eval concurrency. +6. The runner script (`benchmarks/scripts/lm-eval/bench.sh`) uses `MODEL_NAME` from `do_sweep.py`, or auto-discovers the served model from `/v1/models` as a fallback. +7. The runner sources `/infmax-workspace/benchmarks/benchmark_lib.sh`, runs `run_eval --framework lm-eval`, and calls `append_lm_eval_summary`. +8. Eval artifacts are copied to `/logs/eval_results/` for InferenceX launcher-side artifact pickup. + +### EVAL_ONLY mode + +srt-slurm supports an `EVAL_ONLY` mode for CI jobs that should only validate accuracy. This is controlled by environment variables from the InferenceX workflow: + +| Env var | Description | +|---------|-------------| +| `EVAL_ONLY` | Set to `true` to skip the throughput benchmark stage and run eval only | +| `RUN_EVAL` | Set to `true` to run eval after the throughput benchmark completes | +| `EVAL_CONC` | Concurrent requests for lm-eval, normally set by InferenceX from the generated `eval-conc` value | +| `INFMAX_WORKSPACE` | Host path to the InferenceX checkout that should be mounted at `/infmax-workspace` | +| `MODEL_NAME` | Served model alias for OpenAI-compatible requests; set by `do_sweep.py` from `config.served_model_name` | + +When `EVAL_ONLY=true`: +- Stage 4 skips the throughput benchmark entirely. No throughput result JSON is expected from srt-slurm. +- The eval path uses the full `wait_for_model()` health check before starting lm-eval. +- `_run_post_eval()` launches the `lm-eval` runner and returns its exit code. +- Eval failure is fatal because eval is the only purpose of the job. + +When `RUN_EVAL=true` (without `EVAL_ONLY`): +- Throughput benchmark runs normally +- After benchmark completes successfully, eval runs as a post-step +- Eval failure is non-fatal; the benchmark job still succeeds if throughput passed + +### Environment variables + +The following env vars are passed through to the lm-eval runner container: + +| Env var | Purpose | +|---------|---------| +| `RUN_EVAL`, `EVAL_ONLY`, `IS_MULTINODE` | Control whether eval runs and how InferenceX classifies the artifact | +| `FRAMEWORK`, `PRECISION`, `MODEL_PREFIX`, `RUNNER_TYPE`, `SPEC_DECODING` | Benchmark identity metadata for `meta_env.json` | +| `ISL`, `OSL`, `RESULT_FILENAME` | Sequence length and result-file metadata | +| `MODEL`, `MODEL_PATH`, `MODEL_NAME` | Model metadata and the served model alias used for requests | +| `MAX_MODEL_LEN`, `EVAL_MAX_MODEL_LEN` | Context-length metadata used by InferenceX eval helpers when available | +| `PREFILL_TP`, `PREFILL_EP`, `PREFILL_NUM_WORKERS`, `PREFILL_DP_ATTN` | Prefill-side topology metadata | +| `DECODE_TP`, `DECODE_EP`, `DECODE_NUM_WORKERS`, `DECODE_DP_ATTN` | Decode-side topology metadata | +| `EVAL_CONC`, `EVAL_CONCURRENT_REQUESTS` | Eval concurrency controls | + +The runner maps srt-slurm's `PREFILL_DP_ATTN` and `DECODE_DP_ATTN` names to InferenceX's `PREFILL_DP_ATTENTION` and `DECODE_DP_ATTENTION` names before calling `append_lm_eval_summary`. This is required for multi-node summary tables to preserve prefill/decode DPA state. + +### Concurrency + +Eval concurrency is ultimately read by InferenceX's `benchmark_lib.sh` from `EVAL_CONCURRENT_REQUESTS`. The runner script sets that value from `EVAL_CONC` when present, preserves an existing `EVAL_CONCURRENT_REQUESTS` otherwise, and falls back to `256` only if neither variable is set: + +```bash +export EVAL_CONCURRENT_REQUESTS="${EVAL_CONC:-${EVAL_CONCURRENT_REQUESTS:-256}}" +``` + +The InferenceX workflow sets `EVAL_CONC` from the generated `eval-conc` value. For multi-node configs, InferenceX selects the `8k1k` entry with the highest max eligible concurrency for each `(model, runner, framework, precision, spec-decoding, prefill-dp-attn, decode-dp-attn)` group, then sets `eval-conc` to the upper median of that config's eligible concurrency list. If `EVAL_CONC` is not set in the environment, `do_sweep.py` falls back to the max of the recipe benchmark concurrency list. + +### Output + +Eval artifacts are written to `/logs/eval_results/` inside the container: +- `meta_env.json` - metadata used by InferenceX aggregation and summary tables +- `results*.json` - lm-eval scores per task +- `sample*.jsonl` - per-sample outputs + +These are collected by the InferenceX NVIDIA launch scripts and uploaded as workflow artifacts. In eval-only mode the InferenceX workflow expects eval artifacts, not throughput benchmark artifacts. + +### Intricacies +1. Eval floor of 16 + - There is 1 sweep config of conc: [1], which causes evals to take >4hrs to complete. diff --git a/src/srtctl/benchmarks/__init__.py b/src/srtctl/benchmarks/__init__.py index 3a2d6449..088617a6 100644 --- a/src/srtctl/benchmarks/__init__.py +++ b/src/srtctl/benchmarks/__init__.py @@ -4,7 +4,7 @@ """Benchmark runners for srtctl.""" # Import runners to trigger registration -from srtctl.benchmarks import gpqa, gsm8k, longbenchv2, mmlu, mooncake_router, router, sa_bench, sglang_bench +from srtctl.benchmarks import gpqa, gsm8k, lm_eval, longbenchv2, mmlu, mooncake_router, router, sa_bench, sglang_bench from srtctl.benchmarks.base import ( BenchmarkRunner, get_runner, @@ -18,6 +18,7 @@ "list_benchmarks", "register_benchmark", # Runners + "lm_eval", "sa_bench", "sglang_bench", "mmlu", diff --git a/src/srtctl/benchmarks/lm_eval.py b/src/srtctl/benchmarks/lm_eval.py new file mode 100644 index 00000000..c63ec097 --- /dev/null +++ b/src/srtctl/benchmarks/lm_eval.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 SemiAnalysis LLC. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""lm-eval benchmark runner for InferenceX evals.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from srtctl.benchmarks.base import SCRIPTS_DIR, BenchmarkRunner, register_benchmark + +if TYPE_CHECKING: + from srtctl.core.runtime import RuntimeContext + from srtctl.core.schema import SrtConfig + + +@register_benchmark("lm-eval") +class LMEvalRunner(BenchmarkRunner): + """lm-eval accuracy evaluation using InferenceX benchmark_lib. + + Runs lm-eval via the InferenceX benchmark_lib.sh harness, + which handles task selection, result collection, and summary generation. + """ + + @property + def name(self) -> str: + return "lm-eval" + + @property + def script_path(self) -> str: + return "/srtctl-benchmarks/lm-eval/bench.sh" + + @property + def local_script_dir(self) -> str: + return str(SCRIPTS_DIR / "lm-eval") + + def validate_config(self, config: SrtConfig) -> list[str]: + # lm-eval has sensible defaults + return [] + + def build_command( + self, + config: SrtConfig, + runtime: RuntimeContext, + ) -> list[str]: + endpoint = f"http://localhost:{runtime.frontend_port}" + # Always use the container mount path, not the host path. + # INFMAX_WORKSPACE env var contains the host path (used for mount setup + # in runtime.py), but inside the container it's at /infmax-workspace. + infmax_workspace = "/infmax-workspace" + + return [ + "bash", + self.script_path, + endpoint, + infmax_workspace, + ] diff --git a/src/srtctl/benchmarks/scripts/lm-eval/bench.sh b/src/srtctl/benchmarks/scripts/lm-eval/bench.sh new file mode 100755 index 00000000..a10e4e7d --- /dev/null +++ b/src/srtctl/benchmarks/scripts/lm-eval/bench.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 SemiAnalysis LLC. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# lm-eval accuracy evaluation using InferenceX benchmark_lib +# Expects: endpoint [infmax_workspace] + +set -e + +ENDPOINT=$1 +INFMAX_WORKSPACE=${2:-/infmax-workspace} + +# Extract HOST and PORT from endpoint (e.g., http://localhost:8000) +HOST=$(echo "$ENDPOINT" | sed -E 's|https?://||; s|:.*||') +PORT=$(echo "$ENDPOINT" | sed -E 's|.*:([0-9]+).*|\1|') + +echo "lm-eval Config: endpoint=${ENDPOINT}; host=${HOST}; port=${PORT}; workspace=${INFMAX_WORKSPACE}" + +# Auto-discover the served model name from /v1/models if MODEL_NAME is not set. +# This ensures we use the exact name the server recognizes, regardless of what +# $MODEL (the HuggingFace ID from the workflow) is set to. +if [[ -z "${MODEL_NAME:-}" ]]; then + DISCOVERED_MODEL=$(curl -sf "${ENDPOINT}/v1/models" 2>/dev/null \ + | python3 -c "import sys,json; d=json.load(sys.stdin); print(d['data'][0]['id'])" 2>/dev/null || true) + if [[ -n "$DISCOVERED_MODEL" ]]; then + export MODEL_NAME="$DISCOVERED_MODEL" + echo "Auto-discovered MODEL_NAME from /v1/models: ${MODEL_NAME}" + else + echo "WARNING: Could not discover model name from /v1/models, using MODEL_NAME=${MODEL_NAME:-$MODEL}" + fi +else + echo "Using MODEL_NAME from environment: ${MODEL_NAME}" +fi + +# cd to workspace so that relative paths (e.g., utils/evals/*.yaml) resolve +cd "${INFMAX_WORKSPACE}" + +# Source the InferenceX benchmark library +source "${INFMAX_WORKSPACE}/benchmarks/benchmark_lib.sh" + +# Run lm-eval via benchmark_lib +# EVAL_CONC is set by the InferenceX workflow (median of conc list). +# benchmark_lib reads concurrency from EVAL_CONCURRENT_REQUESTS env var. +export EVAL_CONCURRENT_REQUESTS="${EVAL_CONC:-${EVAL_CONCURRENT_REQUESTS:-256}}" +echo "Running lm-eval with concurrent-requests=${EVAL_CONCURRENT_REQUESTS}..." +eval_rc=0 +run_eval --framework lm-eval --port "$PORT" || eval_rc=$? + +# Derive metadata env vars that append_lm_eval_summary needs but do_sweep.py +# does not pass directly (it passes PREFILL_TP/EP/etc, not TP/EP_SIZE/CONC). +export IS_MULTINODE="${IS_MULTINODE:-true}" +export TP="${TP:-${PREFILL_TP:-1}}" +export CONC="${CONC:-${EVAL_CONC:-${EVAL_CONCURRENT_REQUESTS:-1}}}" +export EP_SIZE="${EP_SIZE:-${PREFILL_EP:-1}}" +export DP_ATTENTION="${DP_ATTENTION:-${PREFILL_DP_ATTN:-false}}" +# Remap srt-slurm's DP_ATTN names to InferenceX's DP_ATTENTION names +export PREFILL_DP_ATTENTION="${PREFILL_DP_ATTENTION:-${PREFILL_DP_ATTN:-${DP_ATTENTION:-false}}}" +export DECODE_DP_ATTENTION="${DECODE_DP_ATTENTION:-${DECODE_DP_ATTN:-${DP_ATTENTION:-false}}}" + +# Generate the lm-eval summary +echo "Generating lm-eval summary..." +append_lm_eval_summary || true + +# Copy eval artifacts to /logs/eval_results/ +mkdir -p /logs/eval_results +echo "Copying eval artifacts to /logs/eval_results/..." +cp -v meta_env.json /logs/eval_results/ 2>/dev/null || true +cp -v results*.json /logs/eval_results/ 2>/dev/null || true +cp -v sample*.jsonl /logs/eval_results/ 2>/dev/null || true + +if [[ "$eval_rc" -ne 0 ]]; then + echo "lm-eval evaluation failed with exit code ${eval_rc}" + exit "$eval_rc" +fi + +echo "lm-eval evaluation complete" diff --git a/src/srtctl/cli/do_sweep.py b/src/srtctl/cli/do_sweep.py index ff6eaa91..77b79ac5 100644 --- a/src/srtctl/cli/do_sweep.py +++ b/src/srtctl/cli/do_sweep.py @@ -18,6 +18,7 @@ import os import sys import threading +import time from dataclasses import dataclass from pathlib import Path @@ -179,6 +180,118 @@ def _print_connection_info(self) -> None: logger.info("=" * 60) logger.info("") + def _run_post_eval(self, stop_event: threading.Event) -> int: + """Run lm-eval after the main benchmark completes (or directly in eval-only mode).""" + from srtctl.benchmarks import get_runner + from srtctl.core.health import wait_for_model + + # In eval-only mode the benchmark health check was skipped, so do the + # full model-ready wait here. In post-benchmark mode a quick port + # check is sufficient since the server already served traffic. + if os.environ.get("EVAL_ONLY", "false").lower() == "true": + r = self.config.resources + n_prefill = 0 if r.num_agg > 0 else r.num_prefill + n_decode = r.num_agg if r.num_agg > 0 else r.num_decode + hc = self.config.health_check + logger.info("EVAL_ONLY: Waiting for server health before eval...") + if not wait_for_model( + host=self.runtime.nodes.head, + port=8000, + n_prefill=n_prefill, + n_decode=n_decode, + poll_interval=float(hc.interval_seconds), + timeout=float(hc.max_attempts * hc.interval_seconds), + report_every=60.0, + frontend_type=self.config.frontend.type, + stop_event=stop_event, + ): + logger.error("Server did not become healthy for eval") + return 1 + else: + if not wait_for_port(self.runtime.nodes.head, 8000, timeout=30): + logger.error("Server health check failed before eval - skipping") + return 1 + + try: + runner = get_runner("lm-eval") + except ValueError as e: + logger.error("lm-eval runner not available: %s", e) + return 1 + + eval_log = self.runtime.log_dir / "eval.out" + cmd = runner.build_command(self.config, self.runtime) + + logger.info("Eval command: %s", " ".join(cmd)) + logger.info("Eval log: %s", eval_log) + + # Pass through eval-related env vars. InferenceX writes multi-node + # metadata from these variables in append_lm_eval_summary(). + env_to_set = {} + for var in [ + "RUN_EVAL", + "EVAL_ONLY", + "IS_MULTINODE", + "FRAMEWORK", + "PRECISION", + "MODEL_PREFIX", + "RUNNER_TYPE", + "RESULT_FILENAME", + "SPEC_DECODING", + "ISL", + "OSL", + "MODEL", + "MODEL_PATH", + "MAX_MODEL_LEN", + "EVAL_MAX_MODEL_LEN", + "PREFILL_TP", + "PREFILL_EP", + "PREFILL_DP_ATTN", + "PREFILL_NUM_WORKERS", + "DECODE_TP", + "DECODE_EP", + "DECODE_DP_ATTN", + "DECODE_NUM_WORKERS", + ]: + val = os.environ.get(var) + if val: + env_to_set[var] = val + + # Set MODEL_NAME to the served model name so lm-eval uses the correct + # name for API requests. Without this, benchmark_lib.sh falls back to + # $MODEL (the HuggingFace ID) which the server doesn't recognize. + env_to_set["MODEL_NAME"] = self.config.served_model_name + logger.info("Eval MODEL_NAME: %s", env_to_set["MODEL_NAME"]) + + # Use EVAL_CONC from workflow (median chosen by InferenceX mark_eval_entries), + # falling back to max of benchmark concurrency list. + eval_conc = os.environ.get("EVAL_CONC") + if eval_conc: + env_to_set["EVAL_CONC"] = eval_conc + logger.info("Eval concurrency (from workflow): %s", eval_conc) + else: + conc_list = self.config.benchmark.get_concurrency_list() + if conc_list: + env_to_set["EVAL_CONC"] = str(max(conc_list)) + logger.info("Eval concurrency (max of %s): %s", conc_list, env_to_set["EVAL_CONC"]) + + proc = start_srun_process( + command=cmd, + nodelist=[self.runtime.nodes.head], + output=str(eval_log), + container_image=str(self.runtime.container_image), + container_mounts=self.runtime.container_mounts, + env_to_set=env_to_set, + ) + + while proc.poll() is None: + if stop_event.is_set(): + logger.info("Stop requested, terminating eval") + proc.terminate() + return 1 + time.sleep(1) + + return proc.returncode or 0 + def run(self) -> int: """Run the complete sweep.""" # Create status reporter (fire-and-forget, no-op if not configured) @@ -221,8 +334,27 @@ def run(self) -> int: self._print_connection_info() - # Stage 4: Benchmark (status reported AFTER health check passes) - exit_code = self.run_benchmark(registry, stop_event, reporter) + if os.environ.get("EVAL_ONLY", "false").lower() == "true": + reporter.report(JobStatus.BENCHMARK, JobStage.BENCHMARK, "Running eval-only evaluation") + logger.info("EVAL_ONLY=true: Skipping benchmark stage and running lm-eval evaluation...") + exit_code = self._run_post_eval(stop_event) + if exit_code != 0: + logger.error("Eval-only evaluation failed with exit code %d", exit_code) + else: + logger.info("Eval-only evaluation completed successfully") + else: + # Stage 4: Benchmark (status reported AFTER health check passes) + exit_code = self.run_benchmark(registry, stop_event, reporter) + + # Stage 5: Post-benchmark eval (optional, non-fatal) + if os.environ.get("RUN_EVAL", "false").lower() == "true" and exit_code == 0: + reporter.report(JobStatus.BENCHMARK, JobStage.BENCHMARK, "Running post-benchmark evaluation") + logger.info("RUN_EVAL=true: Running post-benchmark lm-eval evaluation...") + eval_exit = self._run_post_eval(stop_event) + if eval_exit != 0: + logger.warning("Eval failed with exit code %d (benchmark result is still valid)", eval_exit) + else: + logger.info("Post-benchmark eval completed successfully") except Exception as e: logger.exception("Error during sweep: %s", e) diff --git a/src/srtctl/core/runtime.py b/src/srtctl/core/runtime.py index 3e68bdd5..31195ed3 100644 --- a/src/srtctl/core/runtime.py +++ b/src/srtctl/core/runtime.py @@ -231,6 +231,14 @@ def from_config( host_path, container_path = mount_spec.split(":", 1) container_mounts[Path(host_path).resolve()] = Path(container_path) + # Mount InferenceX workspace if available (for lm-eval support). + # Skip exists() check: the orchestrator runs on the SLURM head node + # where the GH Actions workspace path may not be directly accessible, + # but it IS accessible from compute nodes via shared filesystem. + infmax_ws = os.environ.get("INFMAX_WORKSPACE") + if infmax_ws: + container_mounts[Path(infmax_ws)] = Path("/infmax-workspace") + # Add FormattablePath mounts from config.container_mounts # These need to be expanded with the runtime context, so we create a # temporary context first and then update diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 261020c7..c15759b2 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -193,6 +193,62 @@ def test_build_command_includes_tokenizer_path(self): assert cmd[7] == "/model" # tokenizer path +class TestLMEvalRunner: + """Test LM-Eval runner.""" + + def test_registry_includes_lm_eval(self): + """lm-eval is in the benchmark registry.""" + assert "lm-eval" in list_benchmarks() + + def test_get_runner(self): + """Can get lm-eval runner.""" + runner = get_runner("lm-eval") + assert runner.name == "lm-eval" + + def test_script_path(self): + """Script path points to lm-eval bench.sh.""" + runner = get_runner("lm-eval") + assert "lm-eval/bench.sh" in runner.script_path + + def test_local_script_dir(self): + """Local script dir points to lm-eval scripts.""" + runner = get_runner("lm-eval") + assert runner.local_script_dir.endswith("lm-eval") + + def test_validate_config_always_valid(self): + """lm-eval accepts any config.""" + from srtctl.benchmarks.lm_eval import LMEvalRunner + from srtctl.core.schema import BenchmarkConfig, ModelConfig, ResourceConfig, SrtConfig + + runner = LMEvalRunner() + config = SrtConfig( + name="test", + model=ModelConfig(path="/model", container="/image", precision="fp4"), + resources=ResourceConfig(gpu_type="h100"), + benchmark=BenchmarkConfig(type="sa-bench"), + ) + assert runner.validate_config(config) == [] + + def test_build_command(self): + """build_command returns correct bash command.""" + from unittest.mock import MagicMock + + from srtctl.benchmarks.lm_eval import LMEvalRunner + + runner = LMEvalRunner() + runtime = MagicMock() + runtime.frontend_port = 8000 + + config = MagicMock() + cmd = runner.build_command(config, runtime) + assert cmd == [ + "bash", + "/srtctl-benchmarks/lm-eval/bench.sh", + "http://localhost:8000", + "/infmax-workspace", + ] + + class TestScriptsExist: """Test that benchmark scripts exist.""" @@ -209,3 +265,365 @@ def test_mmlu_script_exists(self): """MMLU script exists.""" script = SCRIPTS_DIR / "mmlu" / "bench.sh" assert script.exists() + + +class TestRunPostEval: + """Test SweepOrchestrator._run_post_eval method.""" + + @staticmethod + def _make_orchestrator(): + """Create a SweepOrchestrator with mocked config/runtime.""" + from pathlib import Path + + from srtctl.cli.do_sweep import SweepOrchestrator + from srtctl.core.runtime import Nodes, RuntimeContext + from srtctl.core.schema import ( + BenchmarkConfig, + FrontendConfig, + HealthCheckConfig, + ModelConfig, + ResourceConfig, + SrtConfig, + ) + + config = SrtConfig( + name="test", + model=ModelConfig(path="/model/test-model", container="/image", precision="fp4"), + resources=ResourceConfig( + gpu_type="h100", + gpus_per_node=8, + prefill_nodes=1, + decode_nodes=2, + prefill_workers=1, + decode_workers=2, + ), + benchmark=BenchmarkConfig(type="sa-bench", isl=1024, osl=1024, concurrencies="128x256x512"), + health_check=HealthCheckConfig(max_attempts=3, interval_seconds=1), + frontend=FrontendConfig(type="dynamo"), + ) + runtime = RuntimeContext( + job_id="12345", + run_name="test-run", + nodes=Nodes(head="node0", bench="node0", infra="node0", worker=("node0", "node1", "node2")), + head_node_ip="10.0.0.1", + infra_node_ip="10.0.0.1", + log_dir=Path("/tmp/logs"), + model_path=Path("/model/test-model"), + container_image=Path("/path/to/container.sqsh"), + gpus_per_node=8, + network_interface=None, + container_mounts={}, + environment={}, + ) + return SweepOrchestrator(config=config, runtime=runtime) + + def test_post_benchmark_port_check_fails(self): + """Returns 1 when port check fails in post-benchmark mode.""" + import os + import threading + from unittest.mock import patch + + orch = self._make_orchestrator() + stop = threading.Event() + with patch.dict(os.environ, {"EVAL_ONLY": "false"}, clear=False): + with patch("srtctl.cli.do_sweep.wait_for_port", return_value=False): + result = orch._run_post_eval(stop) + assert result == 1 + + def test_eval_only_health_check_fails(self): + """Returns 1 when health check fails in eval-only mode.""" + import os + import threading + from unittest.mock import patch + + orch = self._make_orchestrator() + stop = threading.Event() + with patch.dict(os.environ, {"EVAL_ONLY": "true"}, clear=False): + with patch("srtctl.core.health.wait_for_model", return_value=False): + result = orch._run_post_eval(stop) + assert result == 1 + + def test_runner_not_available(self): + """Returns 1 when lm-eval runner is not registered.""" + import os + import threading + from unittest.mock import patch + + orch = self._make_orchestrator() + stop = threading.Event() + with patch.dict(os.environ, {"EVAL_ONLY": "false"}, clear=False): + with patch("srtctl.cli.do_sweep.wait_for_port", return_value=True): + with patch("srtctl.benchmarks.get_runner", side_effect=ValueError("not found")): + result = orch._run_post_eval(stop) + assert result == 1 + + def test_successful_eval(self): + """Returns 0 when eval completes successfully.""" + import os + import threading + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + stop = threading.Event() + + mock_proc = MagicMock() + mock_proc.poll.side_effect = [None, 0] + mock_proc.returncode = 0 + + with patch.dict(os.environ, {"EVAL_ONLY": "false"}, clear=False): + with patch("srtctl.cli.do_sweep.wait_for_port", return_value=True): + with patch("srtctl.cli.do_sweep.start_srun_process", return_value=mock_proc): + result = orch._run_post_eval(stop) + assert result == 0 + + def test_eval_only_successful(self): + """Returns 0 in eval-only mode when health check and eval succeed.""" + import os + import threading + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + stop = threading.Event() + + mock_proc = MagicMock() + mock_proc.poll.side_effect = [None, 0] + mock_proc.returncode = 0 + + with patch.dict(os.environ, {"EVAL_ONLY": "true"}, clear=False): + with patch("srtctl.core.health.wait_for_model", return_value=True): + with patch("srtctl.cli.do_sweep.start_srun_process", return_value=mock_proc): + result = orch._run_post_eval(stop) + assert result == 0 + + def test_env_var_passthrough(self): + """Eval env vars are passed through to srun.""" + import os + import threading + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + stop = threading.Event() + + mock_proc = MagicMock() + mock_proc.poll.return_value = 0 + mock_proc.returncode = 0 + + env_vars = { + "EVAL_ONLY": "false", + "RUN_EVAL": "true", + "FRAMEWORK": "sglang", + "PRECISION": "fp4", + "MODEL": "test-model", + } + + captured_kwargs = {} + + def capture_srun(**kwargs): + captured_kwargs.update(kwargs) + return mock_proc + + with patch.dict(os.environ, env_vars, clear=False): + with patch("srtctl.cli.do_sweep.wait_for_port", return_value=True): + with patch("srtctl.cli.do_sweep.start_srun_process", side_effect=capture_srun): + orch._run_post_eval(stop) + + env_to_set = captured_kwargs["env_to_set"] + assert env_to_set["RUN_EVAL"] == "true" + assert env_to_set["FRAMEWORK"] == "sglang" + assert env_to_set["PRECISION"] == "fp4" + assert env_to_set["MODEL"] == "test-model" + assert env_to_set["MODEL_NAME"] == "test-model" + + def test_eval_conc_from_env(self): + """EVAL_CONC from env takes priority over benchmark concurrencies.""" + import os + import threading + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + stop = threading.Event() + + mock_proc = MagicMock() + mock_proc.poll.return_value = 0 + mock_proc.returncode = 0 + + captured_kwargs = {} + + def capture_srun(**kwargs): + captured_kwargs.update(kwargs) + return mock_proc + + with patch.dict(os.environ, {"EVAL_ONLY": "false", "EVAL_CONC": "64"}, clear=False): + with patch("srtctl.cli.do_sweep.wait_for_port", return_value=True): + with patch("srtctl.cli.do_sweep.start_srun_process", side_effect=capture_srun): + orch._run_post_eval(stop) + + assert captured_kwargs["env_to_set"]["EVAL_CONC"] == "64" + + def test_eval_conc_fallback_to_max_concurrency(self): + """EVAL_CONC falls back to max of benchmark concurrencies.""" + import os + import threading + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + stop = threading.Event() + + mock_proc = MagicMock() + mock_proc.poll.return_value = 0 + mock_proc.returncode = 0 + + captured_kwargs = {} + + def capture_srun(**kwargs): + captured_kwargs.update(kwargs) + return mock_proc + + env = {"EVAL_ONLY": "false"} + # Remove EVAL_CONC if present + with patch.dict(os.environ, env, clear=False): + os.environ.pop("EVAL_CONC", None) + with patch("srtctl.cli.do_sweep.wait_for_port", return_value=True): + with patch("srtctl.cli.do_sweep.start_srun_process", side_effect=capture_srun): + orch._run_post_eval(stop) + + # concurrencies="128x256x512", max is 512 + assert captured_kwargs["env_to_set"]["EVAL_CONC"] == "512" + + def test_stop_event_terminates_eval(self): + """Stop event terminates the eval process.""" + import os + import threading + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + stop = threading.Event() + stop.set() + + mock_proc = MagicMock() + mock_proc.poll.return_value = None + + with patch.dict(os.environ, {"EVAL_ONLY": "false"}, clear=False): + with patch("srtctl.cli.do_sweep.wait_for_port", return_value=True): + with patch("srtctl.cli.do_sweep.start_srun_process", return_value=mock_proc): + result = orch._run_post_eval(stop) + + assert result == 1 + mock_proc.terminate.assert_called_once() + + +class TestSweepRunEvalIntegration: + """Test eval-related branches in SweepOrchestrator.run().""" + + @staticmethod + def _make_orchestrator(): + return TestRunPostEval._make_orchestrator() + + def test_run_eval_only_mode(self): + """EVAL_ONLY=true skips benchmark and runs _run_post_eval.""" + import os + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + + with patch.dict(os.environ, {"EVAL_ONLY": "true"}, clear=False): + with patch.object(orch, "start_head_infrastructure") as mock_head: + mock_head.return_value = MagicMock() + with patch.object(orch, "start_all_workers", return_value={}): + with patch.object(orch, "start_frontend", return_value=[]): + with patch.object(orch, "_run_post_eval", return_value=0) as mock_eval: + with patch.object(orch, "run_benchmark") as mock_bench: + with patch.object(orch, "run_postprocess"): + with patch("srtctl.cli.do_sweep.StatusReporter") as mock_reporter_cls: + mock_reporter_cls.from_config.return_value = MagicMock() + exit_code = orch.run() + + mock_eval.assert_called_once() + mock_bench.assert_not_called() + assert exit_code == 0 + + def test_run_with_post_benchmark_eval(self): + """RUN_EVAL=true runs benchmark then _run_post_eval.""" + import os + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + + with patch.dict(os.environ, {"EVAL_ONLY": "false", "RUN_EVAL": "true"}, clear=False): + with patch.object(orch, "start_head_infrastructure") as mock_head: + mock_head.return_value = MagicMock() + with patch.object(orch, "start_all_workers", return_value={}): + with patch.object(orch, "start_frontend", return_value=[]): + with patch.object(orch, "run_benchmark", return_value=0) as mock_bench: + with patch.object(orch, "_run_post_eval", return_value=0) as mock_eval: + with patch.object(orch, "run_postprocess"): + with patch("srtctl.cli.do_sweep.StatusReporter") as mock_reporter_cls: + mock_reporter_cls.from_config.return_value = MagicMock() + exit_code = orch.run() + + mock_bench.assert_called_once() + mock_eval.assert_called_once() + assert exit_code == 0 + + def test_run_eval_only_failure(self): + """EVAL_ONLY=true with eval failure returns non-zero exit code.""" + import os + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + + with patch.dict(os.environ, {"EVAL_ONLY": "true"}, clear=False): + with patch.object(orch, "start_head_infrastructure") as mock_head: + mock_head.return_value = MagicMock() + with patch.object(orch, "start_all_workers", return_value={}): + with patch.object(orch, "start_frontend", return_value=[]): + with patch.object(orch, "_run_post_eval", return_value=1): + with patch.object(orch, "run_postprocess"): + with patch("srtctl.cli.do_sweep.StatusReporter") as mock_reporter_cls: + mock_reporter_cls.from_config.return_value = MagicMock() + exit_code = orch.run() + + assert exit_code == 1 + + def test_run_post_benchmark_eval_failure_nonfatal(self): + """RUN_EVAL=true with eval failure still returns benchmark exit code 0.""" + import os + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + + with patch.dict(os.environ, {"EVAL_ONLY": "false", "RUN_EVAL": "true"}, clear=False): + with patch.object(orch, "start_head_infrastructure") as mock_head: + mock_head.return_value = MagicMock() + with patch.object(orch, "start_all_workers", return_value={}): + with patch.object(orch, "start_frontend", return_value=[]): + with patch.object(orch, "run_benchmark", return_value=0): + with patch.object(orch, "_run_post_eval", return_value=1): + with patch.object(orch, "run_postprocess"): + with patch("srtctl.cli.do_sweep.StatusReporter") as mock_reporter_cls: + mock_reporter_cls.from_config.return_value = MagicMock() + exit_code = orch.run() + + assert exit_code == 0 + + def test_run_eval_skipped_when_benchmark_fails(self): + """RUN_EVAL=true but benchmark fails: eval is skipped.""" + import os + from unittest.mock import MagicMock, patch + + orch = self._make_orchestrator() + + with patch.dict(os.environ, {"EVAL_ONLY": "false", "RUN_EVAL": "true"}, clear=False): + with patch.object(orch, "start_head_infrastructure") as mock_head: + mock_head.return_value = MagicMock() + with patch.object(orch, "start_all_workers", return_value={}): + with patch.object(orch, "start_frontend", return_value=[]): + with patch.object(orch, "run_benchmark", return_value=1): + with patch.object(orch, "_run_post_eval") as mock_eval: + with patch.object(orch, "run_postprocess"): + with patch("srtctl.cli.do_sweep.StatusReporter") as mock_reporter_cls: + mock_reporter_cls.from_config.return_value = MagicMock() + exit_code = orch.run() + + mock_eval.assert_not_called() + assert exit_code == 1 diff --git a/tests/test_configs.py b/tests/test_configs.py index 86d79cdb..0b4138d5 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -1382,3 +1382,113 @@ def test_agg_mode_no_disaggregation_flag(self): assert "--disaggregation-mode" not in cmd assert "--is-prefill-worker" not in cmd assert "--is-decode-worker" not in cmd + + +class TestInfmaxWorkspaceMount: + """Test that INFMAX_WORKSPACE env var creates a container mount.""" + + def test_infmax_workspace_mount_added(self, tmp_path): + """RuntimeContext includes /infmax-workspace mount when env var is set.""" + import os + import subprocess + from pathlib import Path + from unittest.mock import MagicMock, patch + + from srtctl.core.runtime import RuntimeContext + from srtctl.core.schema import ModelConfig, ResourceConfig, SrtConfig + + model_path = tmp_path / "model" + model_path.mkdir() + container_path = tmp_path / "container.sqsh" + container_path.touch() + + slurm_env = { + "SLURM_JOB_ID": "12345", + "SLURM_JOBID": "12345", + "SLURM_NODELIST": "gpu-[01-02]", + "SLURM_JOB_NUM_NODES": "2", + "SRTCTL_SOURCE_DIR": str(Path(__file__).parent.parent), + "INFMAX_WORKSPACE": "/actions/runner/workspace", + } + + def mock_scontrol(cmd, **kwargs): + if cmd[0] == "scontrol" and "hostnames" in cmd: + result = MagicMock() + result.stdout = "gpu-01\ngpu-02" + result.returncode = 0 + return result + raise subprocess.CalledProcessError(1, cmd) + + with patch.dict(os.environ, slurm_env): + with patch("subprocess.run", mock_scontrol): + with patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"): + config = SrtConfig( + name="test", + model=ModelConfig( + path=str(model_path), + container=str(container_path), + precision="fp8", + ), + resources=ResourceConfig( + gpu_type="h100", + gpus_per_node=8, + prefill_nodes=1, + decode_nodes=1, + ), + ) + runtime = RuntimeContext.from_config(config, job_id="12345") + + assert Path("/infmax-workspace") in runtime.container_mounts.values() + + def test_infmax_workspace_mount_not_added_without_env(self, tmp_path): + """RuntimeContext does not include /infmax-workspace without env var.""" + import os + import subprocess + from pathlib import Path + from unittest.mock import MagicMock, patch + + from srtctl.core.runtime import RuntimeContext + from srtctl.core.schema import ModelConfig, ResourceConfig, SrtConfig + + model_path = tmp_path / "model" + model_path.mkdir() + container_path = tmp_path / "container.sqsh" + container_path.touch() + + slurm_env = { + "SLURM_JOB_ID": "12345", + "SLURM_JOBID": "12345", + "SLURM_NODELIST": "gpu-[01-02]", + "SLURM_JOB_NUM_NODES": "2", + "SRTCTL_SOURCE_DIR": str(Path(__file__).parent.parent), + } + + def mock_scontrol(cmd, **kwargs): + if cmd[0] == "scontrol" and "hostnames" in cmd: + result = MagicMock() + result.stdout = "gpu-01\ngpu-02" + result.returncode = 0 + return result + raise subprocess.CalledProcessError(1, cmd) + + with patch.dict(os.environ, slurm_env): + os.environ.pop("INFMAX_WORKSPACE", None) + with patch("subprocess.run", mock_scontrol): + with patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"): + config = SrtConfig( + name="test", + model=ModelConfig( + path=str(model_path), + container=str(container_path), + precision="fp8", + ), + resources=ResourceConfig( + gpu_type="h100", + gpus_per_node=8, + prefill_nodes=1, + decode_nodes=1, + ), + ) + runtime = RuntimeContext.from_config(config, job_id="12345") + + assert Path("/infmax-workspace") not in runtime.container_mounts.values()