diff --git a/README.md b/README.md index 7af1258e..cdbc62b1 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ make setup ARCH=aarch64 # or ARCH=x86_64 - [Parameter Sweeps](docs/sweeps.md) - Grid searches - [Profiling](docs/profiling.md) - Torch/nsys profiling - [Analyzing Results](docs/analyzing.md) - Dashboard and visualization +- [Accuracy Benchmarks](docs/accuracy.md) - Running accuracy benchmarks ## Commands diff --git a/docs/accuracy.md b/docs/accuracy.md index f5588c9f..9ea72c8e 100644 --- a/docs/accuracy.md +++ b/docs/accuracy.md @@ -27,7 +27,7 @@ For MMLU dataset, the benchmark section in yaml file can be modified in the foll benchmark: type: "mmlu" num_examples: 200 # Number of examples to run - max_tokens: 2048 # Max number of output tokens + max_tokens: 8192 # Max number of output tokens. repeat: 8 # Number of repetition num_threads: 512 # Number of parallel threads for running benchmark ``` @@ -40,18 +40,20 @@ srtctl apply -f config.yaml After finishing benchmarking, the `benchmark.out` will contain the results of accuracy: ``` ==================== -Repeat: 8, mean: 0.812 -Scores: ['0.790', '0.820', '0.800', '0.820', '0.820', '0.790', '0.820', '0.840'] +Repeat: 8, mean: 0.895 +Scores: ['0.905', '0.895', '0.900', '0.880', '0.905', '0.890', '0.890', '0.895'] ==================== Writing report to /tmp/mmlu_deepseek-ai_DeepSeek-R1.html -{'other': np.float64(0.9), 'other:std': np.float64(0.30000000000000004), 'score:std': np.float64(0.36660605559646725), 'stem': np.float64(0.8095238095238095), 'stem:std': np.float64(0.392676726249301), 'humanities': np.float64(0.7428571428571429), 'humanities:std': np.float64(0.4370588154508102), 'social_sciences': np.float64(0.9583333333333334), 'social_sciences:std': np.float64(0.19982631347136331), 'score': np.float64(0.84)} +{'other': np.float64(0.9361702127659575), 'other:std': np.float64(0.24444947432076722), 'score:std': np.float64(0.3065534211193866), 'stem': np.float64(0.9285714285714286), 'stem:std': 
np.float64(0.25753937681885636), 'humanities': np.float64(0.8064516129032258), 'humanities:std': np.float64(0.3950789907714804), 'social_sciences': np.float64(0.9387755102040817), 'social_sciences:std': np.float64(0.23974163519328023), 'score': np.float64(0.895)} Writing results to /tmp/mmlu_deepseek-ai_DeepSeek-R1.json -Total latency: 465.618 s -Score: 0.840 +Total latency: 754.457 s +Score: 0.895 Results saved to: /logs/accuracy/mmlu_deepseek-ai_DeepSeek-R1.json MMLU evaluation complete ``` +**Note: `max_tokens` should be large enough to reach expected accuracy. For deepseek-r1-fp4 model, `max_tokens=8192` can reach expected accuracy 0.895, while `max_tokens=2048` can only score at 0.81.** + ## GPQA For GPQA dataset, the benchmark section in yaml file can be modified in the following way: diff --git a/docs/profiling.md b/docs/profiling.md index b6666de9..68633c7e 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -66,7 +66,7 @@ profiling: profiling: type: "torch" # Required: "none", "torch", or "nsys" - # Traffic generator parameters (required when profiling is enabled) +# Traffic generator parameters (required when profiling is enabled) isl: 1024 # Input sequence length osl: 128 # Output sequence length concurrency: 24 # Batch size for profiling workload diff --git a/docs/sglang-router.md b/docs/sglang-router.md index 3dd5098e..44579405 100644 --- a/docs/sglang-router.md +++ b/docs/sglang-router.md @@ -177,6 +177,67 @@ The default bootstrap port is `30001` (matching most recipes). If you use a diff Workers listen on port `30000` by default. This is standard sglang behavior and doesn't need configuration. +## Debugging with SGLang Source Code + +When using sglang-router mode, you can mount and install sglang from source for debugging purposes. This is useful when you need to test local changes or debug issues in sglang itself. 
+ +### Configuration + +Add `sglang_src_dir` to your recipe's `backend` section: + +```yaml +backend: + use_sglang_router: true + sglang_src_dir: "/path/to/your/local/sglang" +``` + +### How It Works + +1. Your local sglang directory is mounted to `/ext-sglang-src/` in the container +2. Before launching workers, the script runs: `pip install -e . --no-deps` +3. Workers use your local sglang code instead of the container's pre-installed version + +### Behavior + +**With `sglang_src_dir` set:** +- Mounts your local sglang source to `/ext-sglang-src/` +- Installs it in editable mode on all prefill/decode/aggregated workers +- Your local changes take effect immediately + +**Without `sglang_src_dir` (or empty):** +- No mount is added +- Installation step is skipped gracefully +- Uses the container's pre-installed sglang + +### Example + +```yaml +name: "debug-sglang-router" + +model: + path: "deepseek-r1-fp4" + container: "0.5.5.post2" + +backend: + use_sglang_router: true + sglang_src_dir: "/home/username/projects/sglang" # Your local sglang checkout + + sglang_config: + # ... 
your config +``` + +Then apply: +```bash +srtctl apply -f recipies/debug-sglang-router.yaml +``` + +### Notes + +- Only works with `use_sglang_router: true` (disaggregation mode) +- The source directory must exist on the host running srtctl +- Dependencies are NOT reinstalled (uses `--no-deps`), so the container must have compatible dependencies already installed +- Useful for iterative debugging without rebuilding containers + ## Complete Example Here's a full recipe using sglang router: diff --git a/examples/fp4-disagg-nsys-profiling.yaml b/examples/fp4-disagg-nsys-profiling.yaml new file mode 100644 index 00000000..83135c08 --- /dev/null +++ b/examples/fp4-disagg-nsys-profiling.yaml @@ -0,0 +1,123 @@ +name: "gb200-fp4-1p2d" + +model: + path: "dsfp4" + container: "0.5.5.post2" + precision: "fp4" + +resources: + gpu_type: "gb200" + prefill_nodes: 1 + decode_nodes: 2 + prefill_workers: 1 + decode_workers: 2 + gpus_per_node: 4 + +backend: + use_sglang_router: "true" + + prefill_environment: + SGLANG_LOG_FORWARD_ITERS: "1" + TORCH_DISTRIBUTED_DEFAULT_TIMEOUT: "1800" + PYTHONUNBUFFERED: "1" + DYN_SKIP_SGLANG_LOG_FORMATTING: "1" + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER: "0" + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK: "1" + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE: "100000" + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000" + SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000" + SGLANG_DECODE_BOOTSTRAP_TIMEOUT: "1000" + #SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN: "1" + #SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2: "1" + MC_FORCE_MNNVL: "1" + NCCL_MNNVL_ENABLE: "1" + NCCL_CUMEM_ENABLE: "1" + SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True" + SGLANG_ENABLE_JIT_DEEPGEMM: "false" + SGLANG_ENABLE_FLASHINFER_GEMM: "true" #instead of SGLANG_FLASHINFER_FP4_GEMM_BACKEND + + decode_environment: + SGLANG_LOG_FORWARD_ITERS: "1" + TORCH_DISTRIBUTED_DEFAULT_TIMEOUT: "1800" + PYTHONUNBUFFERED: "1" + DYN_SKIP_SGLANG_LOG_FORMATTING: "1" + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER: "0" + 
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK: "1" + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE: "100000" + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000" + SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000" + SGLANG_DECODE_BOOTSTRAP_TIMEOUT: "1000" + # SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN: "1" + # SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2: "1" + MC_FORCE_MNNVL: "1" + NCCL_MNNVL_ENABLE: "1" + NCCL_CUMEM_ENABLE: "1" + SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True" + SGLANG_ENABLE_JIT_DEEPGEMM: "false" + SGLANG_ENABLE_FLASHINFER_GEMM: "true" #instead of SGLANG_FLASHINFER_FP4_GEMM_BACKEND + + sglang_config: + prefill: + disaggregation-mode: "prefill" + served-model-name: "deepseek-ai/DeepSeek-R1" + model-path: "/model/" + trust-remote-code: true + disable-radix-cache: true + kv-cache-dtype: "fp8_e4m3" + attention-backend: "trtllm_mla" + quantization: "modelopt_fp4" + moe-runner-backend: "flashinfer_trtllm" + stream-interval: 10 + watchdog-timeout: 1000000 + context-length: 2200 + mem-fraction-static: 0.95 + max-total-tokens: 8192 + chunked-prefill-size: 8192 + cuda-graph-max-bs: 256 + max-running-requests: 512 + scheduler-recv-interval: 10 + enable-symm-mem: true + moe-dense-tp-size: 1 + load-balance-method: "round_robin" + disaggregation-bootstrap-port: 30001 + load-format: "dummy" + data-parallel-size: 1 + tensor-parallel-size: 4 + expert-parallel-size: 1 + + decode: + disaggregation-mode: "decode" + served-model-name: "deepseek-ai/DeepSeek-R1" + model-path: "/model/" + prefill-round-robin-balance: true + trust-remote-code: true + disable-radix-cache: true + kv-cache-dtype: "fp8_e4m3" + attention-backend: "trtllm_mla" + quantization: "modelopt_fp4" + moe-runner-backend: "flashinfer_trtllm" + disaggregation-bootstrap-port: 30001 + stream-interval: 10 + watchdog-timeout: 1000000 + context-length: 2200 + mem-fraction-static: 0.95 + load-format: "dummy" + chunked-prefill-size: 8192 + cuda-graph-max-bs: 256 + scheduler-recv-interval: 10 + enable-symm-mem: true + moe-dense-tp-size: 1 + 
tensor-parallel-size: 4 + expert-parallel-size: 1 + +profiling: + type: "nsys" + isl: 1024 + osl: 1024 + concurrency: 256 + prefill: + start_step: 60 + stop_step: 70 + decode: + start_step: 700 + stop_step: 730 \ No newline at end of file diff --git a/scripts/benchmarks/mmlu/bench.sh b/scripts/benchmarks/mmlu/bench.sh new file mode 100644 index 00000000..9450df9a --- /dev/null +++ b/scripts/benchmarks/mmlu/bench.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# MMLU evaluation script using sglang.test.run_eval with mmlu + +head_node="localhost" +head_port=8000 +model_name="deepseek-ai/DeepSeek-R1" # Default model name + +# Parse arguments from SLURM job +n_prefill=$1 +n_decode=$2 +prefill_gpus=$3 +decode_gpus=$4 +num_examples=${5:-200} # Default: 200 +max_tokens=${6:-8192} # Default: 8192 +repeat=${7:-8} # Default: 8 +num_threads=${8:-512} # Default: 512 + +echo "MMLU Benchmark Config: num_examples=${num_examples}; max_tokens=${max_tokens}; repeat=${repeat}; num_threads=${num_threads}" + +# Source utilities for wait_for_model +source /scripts/utils/benchmark_utils.sh + +wait_for_model_timeout=1500 # 25 minutes +wait_for_model_check_interval=5 # check interval -> 5s +wait_for_model_report_interval=60 # wait_for_model report interval -> 60s + +wait_for_model $head_node $head_port $n_prefill $n_decode $wait_for_model_check_interval $wait_for_model_timeout $wait_for_model_report_interval + +# Create results directory +result_dir="/logs/accuracy" +mkdir -p $result_dir + +echo "Running MMLU evaluation..." 
+ +# Set OPENAI_API_KEY if not set +if [ -z "$OPENAI_API_KEY" ]; then + export OPENAI_API_KEY="EMPTY" +fi + +# Run the evaluation +python3 -m sglang.test.run_eval \ + --base-url "http://${head_node}:${head_port}" \ + --model ${model_name} \ + --eval-name mmlu \ + --num-examples ${num_examples} \ + --max-tokens ${max_tokens} \ + --repeat ${repeat} \ + --num-threads ${num_threads} + +# Copy the result file from /tmp to our logs directory +# The result file is named mmlu_{model_name}.json +result_file=$(ls -t /tmp/mmlu_*.json 2>/dev/null | head -n1) + +if [ -f "$result_file" ]; then + cp "$result_file" "$result_dir/" + echo "Results saved to: $result_dir/$(basename $result_file)" +else + echo "Warning: Could not find result file in /tmp" +fi + +echo "MMLU evaluation complete" diff --git a/scripts/profiling/profile.sh b/scripts/profiling/profile.sh new file mode 100755 index 00000000..fd571055 --- /dev/null +++ b/scripts/profiling/profile.sh @@ -0,0 +1,194 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Torch profiling script for sglang.launch_server +# This script runs bench_one_batch_server with profiling enabled + +model_name="deepseek-ai/DeepSeek-R1" +head_node="${HEAD_NODE:-127.0.0.1}" +head_port="${HEAD_PORT:-8000}" + +# Parse arguments (same as sa-bench for consistency) +n_prefill=$1 +n_decode=$2 +prefill_gpus=$3 +decode_gpus=$4 +total_gpus=$5 + +echo "Torch Profiling Configuration:" +echo " Profiling dir: ${SGLANG_TORCH_PROFILER_DIR}" +echo " Prefill workers: ${n_prefill}" +echo " Decode workers: ${n_decode}" +echo " Prefill GPUs: ${prefill_gpus}" +echo " Decode GPUs: ${decode_gpus}" +echo " Total GPUs: ${total_gpus}" + +# Wait for server to be ready using inline wait function +wait_until_ready() { + local SERVER_URL="$1" + while true; do + status_code=$(curl -s -o /dev/null -w "%{http_code}" "${SERVER_URL}/health" || echo "000") + if [ "$status_code" -eq 200 ]; then + echo "Server ${SERVER_URL} is ready" + break + fi + echo "Server not ready yet (status: ${status_code}), waiting..." + top -b -n 1 | head -10 + PID=$(nvidia-smi --query-compute-apps=pid --format=csv,noheader -i 0 | tr -d ' ' | head -n1) + [ -n "$PID" ] && py-spy dump -s --pid $PID > /logs/py-spy-dump-${SLURM_NODEID:-0}.txt || echo "No GPU process found" + sleep 30 + done +} + +# Parse leader IP lists from environment (comma-separated) +IFS=',' read -r -a PREFILL_IPS <<< "${PROFILE_PREFILL_IPS:-}" +IFS=',' read -r -a DECODE_IPS <<< "${PROFILE_DECODE_IPS:-}" +IFS=',' read -r -a AGG_IPS <<< "${PROFILE_AGG_IPS:-}" + +wait_all_workers_ready() { + local ips=("$@") + for ip in "${ips[@]}"; do + if [[ -z "${ip}" ]]; then + continue + fi + echo "Waiting for worker at http://${ip}:30000 to be ready..." + wait_until_ready "http://${ip}:30000" + done +} + +if [[ "${#PREFILL_IPS[@]}" -gt 0 || "${#DECODE_IPS[@]}" -gt 0 || "${#AGG_IPS[@]}" -gt 0 ]]; then + echo "Waiting for all profiling workers to be ready..." 
+ wait_all_workers_ready "${PREFILL_IPS[@]}" "${DECODE_IPS[@]}" "${AGG_IPS[@]}" +else + echo "Error: node ip not set for profiling." + exit 1 +fi + +echo "Waiting for serving endpoint at http://${head_node}:${head_port} to be ready..." +wait_until_ready "http://${head_node}:${head_port}" + +# Determine profiling parameters strictly from environment +PROFILE_STEPS_ARG="" +CLI_ARGS="" +[[ -n "${PROFILE_CONCURRENCY}" ]] && CLI_ARGS+=" --batch-size ${PROFILE_CONCURRENCY}" +# Require ISL/OSL to be provided; do not pass them as CLI args here +if [[ -z "${PROFILE_ISL}" || -z "${PROFILE_OSL}" ]]; then + echo "Error: isl and osl must be set for profiling." + exit 1 +fi +if [[ -z "${PROFILE_CONCURRENCY}" ]]; then + echo "Error: concurrency must be set for profiling." + exit 1 +fi + +get_phase_start_step() { + local phase="$1" + local var_name="PROFILE_${phase}_START_STEP" + local value="${!var_name}" + echo "${value}" +} + +get_phase_stop_step() { + local phase="$1" + local var_name="PROFILE_${phase}_STOP_STEP" + local value="${!var_name}" + echo "${value}" +} + + +echo "Running profiler..." 
+echo "$(date '+%Y-%m-%d %H:%M:%S')" + +# Create profiling output directory only when torch profiler dir is provided +ACTIVITIES="" +if [[ -n "${SGLANG_TORCH_PROFILER_DIR}" ]]; then + ACTIVITIES='["CPU", "GPU", "MEM"]' + mkdir -p "${SGLANG_TORCH_PROFILER_DIR}" 2>/dev/null || true + export SGLANG_TORCH_PROFILER_DIR=${SGLANG_TORCH_PROFILER_DIR} +else + ACTIVITIES='["CUDA_PROFILER"]' + mkdir -p "/logs/profiles" 2>/dev/null || true +fi + +set -x + +start_profile_on_worker() { + local ip="$1" + local start_step="$2" + local stop_step="$3" + if [[ -z "${ip}" ]]; then + return + fi + if [[ -z "${start_step}" ]]; then + echo "Warning: profiling start_step not set; defaulting to 0" + start_step=0 + fi + if [[ -z "${stop_step}" ]]; then + echo "Warning: profiling stop_step not set; defaulting to 50" + stop_step=50 + fi + local num_steps=$((stop_step - start_step)) + if [[ "${num_steps}" -le 0 ]]; then + echo "Error: invalid profiling step range: start_step=${start_step} stop_step=${stop_step}" + return 1 + fi + echo "Starting profiling on http://${ip}:30000" + curl -X POST "http://${ip}:30000/start_profile" -H "Content-Type: application/json" -d "{\"start_step\": \"${start_step}\", \"num_steps\": ${num_steps}, \"activities\": $ACTIVITIES}" +} + +slow_down_first_decode_worker() { + local ip="$1" + if [[ -z "${ip}" ]]; then + return + fi + echo "Slowing down first decode worker at http://${ip}:30000" + curl -sS -X POST "http://${ip}:30000/slow_down" -H "Content-Type: application/json" -d '{"forward_sleep_time": 120.0}' || true +} + +prefill_start_step="$(get_phase_start_step PREFILL)" +prefill_stop_step="$(get_phase_stop_step PREFILL)" +decode_start_step="$(get_phase_start_step DECODE)" +decode_stop_step="$(get_phase_stop_step DECODE)" +agg_start_step="$(get_phase_start_step AGG)" +agg_stop_step="$(get_phase_stop_step AGG)" + +for ip in "${PREFILL_IPS[@]}"; do + start_profile_on_worker "${ip}" "${prefill_start_step}" "${prefill_stop_step}" +done +# 
slow_down_first_decode_worker "${DECODE_IPS[0]}" +for ip in "${DECODE_IPS[@]}"; do + start_profile_on_worker "${ip}" "${decode_start_step}" "${decode_stop_step}" +done +for ip in "${AGG_IPS[@]}"; do + start_profile_on_worker "${ip}" "${agg_start_step}" "${agg_stop_step}" +done + + +python3 -m sglang.bench_serving \ +--backend sglang \ +--model ${model_name} \ +--host ${head_node} --port ${head_port} \ +--dataset-name random \ +--max-concurrency $PROFILE_CONCURRENCY \ +--num-prompts 128 \ +--random-input-len $PROFILE_ISL \ +--random-output-len $PROFILE_OSL \ +--random-range-ratio 1 \ +--warmup-request 0 + +pip install lm-eval tenacity > /dev/null +python -m lm_eval \ +--model local-completions \ +--tasks gsm8k \ +--model_args base_url=http://${head_node}:${head_port}/v1/completions,model=${model_name},tokenized_requests=False,tokenizer_backend=None,num_concurrent=${PROFILE_CONCURRENCY},timeout=6000,max_retries=1 \ +--limit 10 + +exit_code=$? +set +x + +echo "$(date '+%Y-%m-%d %H:%M:%S')" +echo "Torch profiling completed with exit code ${exit_code}" +echo "Profiling results saved to ${SGLANG_TORCH_PROFILER_DIR}" + +exit ${exit_code} diff --git a/scripts/templates/job_script_template_agg.j2 b/scripts/templates/job_script_template_agg.j2 new file mode 100755 index 00000000..f3ea79b4 --- /dev/null +++ b/scripts/templates/job_script_template_agg.j2 @@ -0,0 +1,383 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes={{ total_nodes }} +#SBATCH --ntasks={{ total_nodes }} +#SBATCH --ntasks-per-node=1 +{% if use_gpus_per_node_directive %} +#SBATCH --gpus-per-node={{ gpus_per_node }} +{% endif %} +#SBATCH --account={{ account }} +#SBATCH --time={{ time_limit }} +#SBATCH --output={{ log_dir_prefix }}/%j_{{ agg_workers }}A_{{ timestamp }}/log.out +#SBATCH --partition={{ partition }} + +# Constants +set -x +AGG_NODES={{ agg_nodes }} +AGG_WORKERS={{ agg_workers }} +TOTAL_NODES={{ total_nodes }} +GPUS_PER_NODE={{ gpus_per_node }} +TOTAL_GPUS=$((AGG_NODES * 
GPUS_PER_NODE)) +PREFILL_GPUS=0 +DECODE_GPUS=$TOTAL_GPUS +AGG_NODES_PER_WORKER=$((AGG_NODES / AGG_WORKERS)) +{% if log_dir_prefix.startswith('/') %} +LOG_DIR="{{ log_dir_prefix }}/${SLURM_JOB_ID}_{{ agg_workers }}A_{{ timestamp }}" +{% else %} +LOG_DIR="${SLURM_SUBMIT_DIR}/{{ log_dir_prefix }}/${SLURM_JOB_ID}_{{ agg_workers }}A_{{ timestamp }}" +{% endif %} +SCRIPT_DIR="${SLURM_SUBMIT_DIR}/scripts" +OUTPUT_DIR="${SLURM_SUBMIT_DIR}/outputs" +MODEL_DIR="{{ model_dir }}" +CONFIG_DIR="{{ config_dir }}" +CONTAINER_IMAGE="{{ container_image }}" +NETWORK_INTERFACE="{{ network_interface }}" +GPU_TYPE="{{ gpu_type | default('h100') }}" +set +x + +{% raw %} + +mkdir -p "${OUTPUT_DIR}" "${LOG_DIR}" + +# Source utility functions for robust IP discovery +source "${SCRIPT_DIR}/utils/slurm_utils.sh" + +nodes=($(scontrol show hostnames $SLURM_NODELIST)) +if [ ${#nodes[@]} -ne $TOTAL_NODES ]; then + echo "Error: Expected $TOTAL_NODES nodes but got ${#nodes[@]} nodes" + exit 1 +fi + +# Print node information +for i in "${!nodes[@]}"; do + echo "Node $i: ${nodes[$i]}" +done + +{% endraw %} +{% if enable_multiple_frontends %} +{% raw %} +# Multiple frontend architecture +# Node 0: nginx + aggregated worker shard +# Node 1: NATS/ETCD + first frontend +# Node 2+: aggregated workers + optional additional frontends + +NGINX_NODE=${nodes[0]} +{% endraw %} +{% if total_nodes > 1 %} +{% raw %} +MASTER_NODE=${nodes[1]} +{% endraw %} +{% else %} +{% raw %} +MASTER_NODE=${nodes[0]} +{% endraw %} +{% endif %} +{% raw %} +MASTER_IP=$(get_node_ip "$MASTER_NODE" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") +if [ -z "$MASTER_IP" ]; then + echo "Error: Could not retrieve IP address for master host $MASTER_NODE" + exit 1 +fi +echo "Master IP address (node 1): $MASTER_IP" +echo "Nginx node (node 0): $NGINX_NODE" + +# Generate frontend IP list for nginx config +frontend_hosts=() +frontend_ips=() +# Node 1 always has a frontend (with NATS/ETCD) +frontend_hosts+=("$MASTER_NODE") +frontend_ips+=("$MASTER_IP") + 
+# Add additional frontends based on num_additional_frontends +{% endraw %}ADDITIONAL_FRONTENDS={{ num_additional_frontends }}{% raw %} +if [ "$ADDITIONAL_FRONTENDS" -gt 0 ]; then + # Calculate which nodes get additional frontends + # We have AGG_NODES aggregated worker nodes, distribute additional frontends across them + nodes_per_frontend=$(( (AGG_NODES - 1 + ADDITIONAL_FRONTENDS - 1) / ADDITIONAL_FRONTENDS )) # ceil division + frontend_node_idx=2 # Start from node 2 (node 1 already has frontend) + + for i in $(seq 1 $ADDITIONAL_FRONTENDS); do + if [ $frontend_node_idx -lt $TOTAL_NODES ]; then + node_name=${nodes[$frontend_node_idx]} + node_ip=$(get_node_ip "$node_name" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + frontend_hosts+=("$node_name") + frontend_ips+=("$node_ip") + echo "Additional frontend $i on node $frontend_node_idx: $node_name ($node_ip)" + frontend_node_idx=$((frontend_node_idx + nodes_per_frontend)) + fi + done +fi + +echo "Frontend hosts: ${frontend_hosts[@]}" +echo "Frontend IPs: ${frontend_ips[@]}" + +{% endraw %} +{% if total_nodes > 1 %} +{% raw %} +# Generate nginx configuration +# Build a Python list literal of frontend hosts from the bash array +FRONTEND_LIST=$(printf "'%s'," "${frontend_ips[@]}") +FRONTEND_LIST="[${FRONTEND_LIST%,}]" +export FRONTEND_LIST SCRIPT_DIR LOG_DIR +python3 - <<'PY' +import os +from jinja2 import Template + +template_path = os.path.join(os.environ['SCRIPT_DIR'], 'templates/nginx.conf.j2') +output_path = os.path.join(os.environ['LOG_DIR'], 'nginx.conf') + +with open(template_path, 'r') as f: + tmpl = Template(f.read()) + +frontend_hosts = eval(os.environ['FRONTEND_LIST']) +config = tmpl.render(frontend_hosts=frontend_hosts) + +with open(output_path, 'w') as f: + f.write(config) +PY +{% endraw %} +{% endif %} +{% raw %} + +{% endraw %} +{% else %} +{% raw %} +# Traditional architecture - first aggregated worker node handles everything +MASTER_IP=$(get_node_ip "${nodes[0]}" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") +if [ -z 
"$MASTER_IP" ]; then + echo "Error: Could not retrieve IP address for master host ${nodes[0]}" + exit 1 +fi +echo "Master IP address: $MASTER_IP" +{% endraw %} +{% endif %} +{% raw %} + +# Compute leader nodes for each aggregated worker +{% endraw %} +{% if enable_multiple_frontends %} +{% raw %} +# With multiple frontends: keep offset 0; nginx coexists on node 0 +WORKER_NODE_OFFSET=0 +{% endraw %} +{% else %} +{% raw %} +# Traditional: workers start from node 0 +WORKER_NODE_OFFSET=0 +{% endraw %} +{% endif %} +{% raw %} + +agg_leaders=() +for i in $(seq 0 $((AGG_WORKERS - 1))); do + leader_idx=$((WORKER_NODE_OFFSET + i * AGG_NODES_PER_WORKER)) + agg_leaders[$i]=$leader_idx +done + +echo "Aggregated worker leaders: ${agg_leaders[@]}" + +# Prepare enroot arguments to pass to srun commands +ENROOT_ARGS="\ + --container-image=${CONTAINER_IMAGE} \ + --no-container-entrypoint \ + --no-container-mount-home \ + --container-mounts=${MODEL_DIR}:/model/,${CONFIG_DIR}:/configs/,${SCRIPT_DIR}:/scripts/,${OUTPUT_DIR}:/outputs/,${LOG_DIR}:/logs/{% endraw %}{% if extra_container_mounts %},{{ extra_container_mounts }}{% endif %}{% raw %} \ +" + +# Build common worker arguments +{% raw %} +WORKER_ARGS="--gpu_type ${GPU_TYPE} --gpus_per_node ${GPUS_PER_NODE} --master_ip ${MASTER_IP}" +{% endraw %} +{% if enable_multiple_frontends %} +{% raw %} +# Add multiple frontends flag for worker setup +WORKER_ARGS="$WORKER_ARGS --multiple-frontends-enabled" +{% endraw %} +{% endif %} +{% raw %} +# Set profiler mode from config +WORKER_ARGS="$WORKER_ARGS --profiler {% endraw %}{{ profiler }}{% raw %}" +{% endraw %} +{% raw %} +# Add SGLang config path (mounted in container at /logs/) +WORKER_ARGS="$WORKER_ARGS --sglang-config-path /logs/sglang_config.yaml" +{% endraw %} +{% if setup_script %} +# Add custom setup script if provided +WORKER_ARGS="$WORKER_ARGS --setup-script {{ setup_script }}" +{% endif %} +{% raw %} + +{% endraw %} +{% if enable_multiple_frontends %} +{% raw %} +{% endraw %} +{% 
if total_nodes > 1 %} +{% raw %} +# Launch nginx on node 0 +echo "Launching nginx on ${NGINX_NODE}" +cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$NGINX_NODE --output=${LOG_DIR}/${NGINX_NODE}_nginx.out python /scripts/worker_setup.py --worker_type nginx --nginx_config /logs/nginx.conf ${WORKER_ARGS}" +echo "$cmd" +$cmd & +{% endraw %} +{% endif %} +{% raw %} + +# Launch frontend on master node (node 1) - this will also start NATS/ETCD +echo "Launching frontend + NATS/ETCD on master node ${MASTER_NODE}" +cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$MASTER_NODE --output=${LOG_DIR}/${MASTER_NODE}_frontend_0.out python /scripts/worker_setup.py --worker_type frontend --worker_idx 0 ${WORKER_ARGS}" +echo "$cmd" +$cmd & + +# Launch additional frontends on designated nodes +if [ "$ADDITIONAL_FRONTENDS" -gt 0 ]; then + frontend_idx=1 # Start from 1 since node 1 is frontend 0 + nodes_per_frontend=$(( (TOTAL_NODES - 2 + ADDITIONAL_FRONTENDS - 1) / ADDITIONAL_FRONTENDS )) + frontend_node_idx=2 + + for i in $(seq 1 $ADDITIONAL_FRONTENDS); do + if [ $frontend_node_idx -lt $TOTAL_NODES ]; then + node=${nodes[$frontend_node_idx]} + echo "Launching additional frontend $frontend_idx on node $frontend_node_idx: $node" + cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_frontend_${frontend_idx}.out python /scripts/worker_setup.py --worker_type frontend --worker_idx ${frontend_idx} ${WORKER_ARGS}" + echo "$cmd" + $cmd & + frontend_idx=$((frontend_idx + 1)) + frontend_node_idx=$((frontend_node_idx + nodes_per_frontend)) + fi + done +fi +{% endraw %} +{% else %} +{% raw %} +# Traditional: first aggregated worker node also runs frontend + NATS/ETCD +# This is handled in setup_aggregated_worker when worker_idx=0 and local_rank=0 +{% endraw %} +{% endif %} +{% raw %} + +# Launch aggregated workers +for worker_idx in $(seq 0 $((AGG_WORKERS - 1))); do + leader_idx=${agg_leaders[$worker_idx]} + 
leader_node=${nodes[$leader_idx]} + + # Get leader IP for this worker group + LEADER_IP=$(get_node_ip "$leader_node" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + echo "Aggregated worker $worker_idx leader: $leader_node ($LEADER_IP)" + + # Launch all nodes for this worker + for node_idx in $(seq 0 $((AGG_NODES_PER_WORKER - 1))); do + global_node_idx=$((leader_idx + node_idx)) + node=${nodes[$global_node_idx]} + local_rank=$node_idx + + echo "Launching aggregated worker $worker_idx, node $global_node_idx (local_rank $local_rank): $node" +{% endraw %} +{% if enable_config_dump %} +{% raw %} + CONFIG_DUMP_ARG="--dump-config-path /logs/${node}_config.json" +{% endraw %} +{% else %} +{% raw %} + CONFIG_DUMP_ARG="" +{% endraw %} +{% endif %} +{% raw %} + cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_agg_w${worker_idx}.out python /scripts/worker_setup.py --leader_ip ${LEADER_IP} --worker_idx ${worker_idx} --local_rank ${local_rank} --nodes_per_worker ${AGG_NODES_PER_WORKER} --worker_type aggregated ${CONFIG_DUMP_ARG} ${WORKER_ARGS}" + echo "$cmd" + $cmd & + done +done + +echo "" +{% endraw %} +{% if enable_multiple_frontends %} +{% raw %} +echo "Frontend available at: http://${NGINX_NODE}:8000" +{% endraw %} +{% if total_nodes > 1 %} +{% raw %} +echo "To connect to the nginx node:" +echo "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${NGINX_NODE} --overlap --pty bash" +{% endraw %} +{% else %} +{% raw %} +echo "To connect to the master node:" +echo "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${NGINX_NODE} --overlap --pty bash" +{% endraw %} +{% endif %} +{% raw %} +echo "To connect to the master node (NATS/ETCD):" +echo "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${MASTER_NODE} --overlap --pty bash" +{% endraw %} +{% else %} +{% raw %} +echo "To connect to the master node:" +echo "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --overlap --pty bash" +{% endraw %} +{% endif %} + +{% if do_benchmark %} +{% raw %} 
+BENCHMARK_TYPE={% endraw %}{{ benchmark_type }}{% raw %} +BENCHMARK_ARGS="{% endraw %}{{ benchmark_arg }}{% raw %}" +srun --nodes=1 --ntasks=1 $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --output=${LOG_DIR}/benchmark.out --overlap bash /scripts/benchmarks/${BENCHMARK_TYPE}/bench.sh $AGG_WORKERS 0 0 $DECODE_GPUS ${BENCHMARK_ARGS} & +{% endraw %} +{% endif %} + +{% if profiler != 'none' %} + +{% raw %} +# Torch profiling mode for aggregated workers +echo "Starting torch profiling on aggregated worker..." + +# Get leader node for first aggregated worker +AGG_LEADER_NODE=${nodes[${agg_leaders[0]}]} + +echo "Aggregated profiling will run on: $AGG_LEADER_NODE" + +# Build aggregated leader IP list (comma-separated) for profiling +agg_leader_ips=() +for worker_idx in $(seq 0 $((AGG_WORKERS - 1))); do + leader_idx=${agg_leaders[$worker_idx]} + leader_node=${nodes[$leader_idx]} + leader_ip=$(get_node_ip "$leader_node" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + if [[ -z "$leader_ip" ]]; then + echo "Error: failed to get leader IP for aggregated worker $worker_idx ($leader_node)" + exit 1 + fi + agg_leader_ips+=("$leader_ip") +done +PROFILE_AGG_IPS=$(IFS=','; echo "${agg_leader_ips[*]}") + +# Run profiling on first aggregated worker's leader node +# Use "decode" mode for aggregated since it profiles the full generation pipeline +srun --nodes=1 --ntasks=1 $ENROOT_ARGS --jobid $SLURM_JOB_ID -w $AGG_LEADER_NODE \ + --output=${LOG_DIR}/profile_aggregated.out --overlap \ + bash -c "{% endraw %}{% if profiler == 'torch' %}SGLANG_TORCH_PROFILER_DIR=/logs/profiles/aggregated {% endif %}{{ profiling_driver_env }} {{ aggregated_profile_env }}{% raw %} PROFILE_AGG_IPS=${PROFILE_AGG_IPS} HEAD_NODE=${agg_leader_ips[0]} HEAD_PORT=30000 /scripts/profiling/profile.sh 0 $AGG_WORKERS 0 $DECODE_GPUS $TOTAL_GPUS" & +{% endraw %} +{% endif %} + +{% if profiler != 'none' %} +{% raw %} +# Wait for profiling script to complete +echo "Waiting for profiling script to complete..." 
+wait +exit_code=$? +echo "Profiling script finished at $(date) with exit code ${exit_code}" +exit $exit_code +{% endraw %} +{% else %} +{% raw %} +# Wait for first task to complete +wait -n +first_exit_code=$? +echo "Script finished at $(date) with exit code ${first_exit_code}" +exit $first_exit_code +{% endraw %} +{% endif %} + +echo "" +echo "Make sure to cancel the job at the end:" +echo "scancel $SLURM_JOB_ID" + diff --git a/scripts/templates/job_script_template_disagg.j2 b/scripts/templates/job_script_template_disagg.j2 new file mode 100755 index 00000000..b6c047db --- /dev/null +++ b/scripts/templates/job_script_template_disagg.j2 @@ -0,0 +1,564 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes={{ total_nodes }} +#SBATCH --ntasks={{ total_nodes }} +{% if use_segment_sbatch_directive %} +#SBATCH --segment={{ total_nodes }} +{% endif %} +#SBATCH --ntasks-per-node=1 +{% if use_gpus_per_node_directive %} +#SBATCH --gpus-per-node={{ gpus_per_node }} +{% endif %} +#SBATCH --account={{ account }} +#SBATCH --time={{ time_limit }} +#SBATCH --output={{ log_dir_prefix }}/%j_{{ prefill_workers }}P_{{ decode_workers }}D_{{ timestamp }}/log.out +#SBATCH --partition={{ partition }} + +# Constants +set -x +PREFILL_NODES={{ prefill_nodes }} +DECODE_NODES={{ decode_nodes }} +PREFILL_WORKERS={{ prefill_workers }} +DECODE_WORKERS={{ decode_workers }} +TOTAL_NODES=$((PREFILL_NODES + DECODE_NODES)) +GPUS_PER_NODE={{ gpus_per_node }} +TOTAL_GPUS=$((TOTAL_NODES * GPUS_PER_NODE)) +PREFILL_GPUS=$((PREFILL_NODES * GPUS_PER_NODE)) +DECODE_GPUS=$((DECODE_NODES * GPUS_PER_NODE)) +PREFILL_NODES_PER_WORKER=$((PREFILL_NODES / PREFILL_WORKERS)) +DECODE_NODES_PER_WORKER=$((DECODE_NODES / DECODE_WORKERS)) +{% if log_dir_prefix.startswith('/') %} +LOG_DIR="{{ log_dir_prefix }}/${SLURM_JOB_ID}_{{ prefill_workers }}P_{{ decode_workers }}D_{{ timestamp }}" +{% else %} +LOG_DIR="${SLURM_SUBMIT_DIR}/{{ log_dir_prefix }}/${SLURM_JOB_ID}_{{ prefill_workers }}P_{{ decode_workers 
}}D_{{ timestamp }}" +{% endif %} +SCRIPT_DIR="${SLURM_SUBMIT_DIR}/scripts" +OUTPUT_DIR="${SLURM_SUBMIT_DIR}/outputs" +MODEL_DIR="{{ model_dir }}" +CONFIG_DIR="{{ config_dir }}" +CONTAINER_IMAGE="{{ container_image }}" +NETWORK_INTERFACE="{{ network_interface }}" +GPU_TYPE="{{ gpu_type | default('h100') }}" +set +x + +{% raw %} + +mkdir -p "${OUTPUT_DIR}" "${LOG_DIR}" + +# Source utility functions for robust IP discovery +source "${SCRIPT_DIR}/utils/slurm_utils.sh" + +nodes=($(scontrol show hostnames $SLURM_NODELIST)) +if [ ${#nodes[@]} -ne $TOTAL_NODES ]; then + echo "Error: Expected $TOTAL_NODES nodes but got ${#nodes[@]} nodes" + exit 1 +fi + +# Print node information +for i in "${!nodes[@]}"; do + echo "Node $i: ${nodes[$i]}" +done + +{% endraw %} +{% if enable_multiple_frontends and not use_sglang_router %} +{% raw %} +# Multiple frontend architecture +# Node 0: nginx only + prefill shard +# Node 1: NATS/ETCD + first frontend + prefill shard +# Node 2+: prefill/decode workers + optional additional frontends + +NGINX_NODE=${nodes[0]} +MASTER_NODE=${nodes[1]} +MASTER_IP=$(get_node_ip "$MASTER_NODE" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") +if [ -z "$MASTER_IP" ]; then + echo "Error: Could not retrieve IP address for master host $MASTER_NODE" + exit 1 +fi +echo "Master IP address (node 1): $MASTER_IP" +echo "Nginx node (node 0): $NGINX_NODE" + +# Generate frontend IP list for nginx config +frontend_hosts=() +frontend_ips=() +# Node 1 always has a frontend (with NATS/ETCD) +frontend_hosts+=("$MASTER_NODE") +frontend_ips+=("$MASTER_IP") + +# Add additional frontends based on num_additional_frontends +{% endraw %}ADDITIONAL_FRONTENDS={{ num_additional_frontends }}{% raw %} +if [ "$ADDITIONAL_FRONTENDS" -gt 0 ]; then + # Calculate which nodes get additional frontends + # We have TOTAL_NODES prefill/decode nodes, distribute additional frontends across them + nodes_per_frontend=$(( (TOTAL_NODES - 1 + ADDITIONAL_FRONTENDS - 1) / ADDITIONAL_FRONTENDS )) # ceil division + 
frontend_node_idx=2 # Start from node 2 (node 1 already has frontend) + + for i in $(seq 1 $ADDITIONAL_FRONTENDS); do + if [ $frontend_node_idx -lt $TOTAL_NODES ]; then + node_name=${nodes[$frontend_node_idx]} + node_ip=$(get_node_ip "$node_name" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + frontend_hosts+=("$node_name") + frontend_ips+=("$node_ip") + echo "Additional frontend $i on node $frontend_node_idx: $node_name ($node_ip)" + frontend_node_idx=$((frontend_node_idx + nodes_per_frontend)) + fi + done +fi + +echo "Frontend hosts: ${frontend_hosts[@]}" +echo "Frontend IPs: ${frontend_ips[@]}" + +# Generate nginx configuration +# Build a Python list literal of frontend hosts from the bash array +FRONTEND_LIST=$(printf "'%s'," "${frontend_ips[@]}") +FRONTEND_LIST="[${FRONTEND_LIST%,}]" +export FRONTEND_LIST SCRIPT_DIR LOG_DIR +python3 - <<'PY' +import os +from jinja2 import Template + +template_path = os.path.join(os.environ['SCRIPT_DIR'], 'templates/nginx.conf.j2') +output_path = os.path.join(os.environ['LOG_DIR'], 'nginx.conf') + +with open(template_path, 'r') as f: + tmpl = Template(f.read()) + +frontend_hosts = eval(os.environ['FRONTEND_LIST']) +config = tmpl.render(frontend_hosts=frontend_hosts) + +with open(output_path, 'w') as f: + f.write(config) +PY + +{% endraw %} +{% else %} +{% raw %} +# Traditional architecture - first prefill node handles everything +MASTER_IP=$(get_node_ip "${nodes[0]}" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") +if [ -z "$MASTER_IP" ]; then + echo "Error: Could not retrieve IP address for master host ${nodes[0]}" + exit 1 +fi +echo "Master IP address: $MASTER_IP" +{% endraw %} +{% endif %} +{% raw %} + +# Compute leader nodes for each worker +{% endraw %} +{% if enable_multiple_frontends and not use_sglang_router %} +{% raw %} +# With multiple frontends: keep offset 0; nginx coexists on node 0 +WORKER_NODE_OFFSET=0 +{% endraw %} +{% else %} +{% raw %} +# Traditional: workers start from node 0 +WORKER_NODE_OFFSET=0 +{% endraw %} +{% endif %} 
+{% raw %} + +prefill_leaders=() +for i in $(seq 0 $((PREFILL_WORKERS - 1))); do + leader_idx=$((WORKER_NODE_OFFSET + i * PREFILL_NODES_PER_WORKER)) + prefill_leaders[$i]=$leader_idx +done + +decode_leaders=() +for i in $(seq 0 $((DECODE_WORKERS - 1))); do + leader_idx=$((WORKER_NODE_OFFSET + PREFILL_NODES + i * DECODE_NODES_PER_WORKER)) + decode_leaders[$i]=$leader_idx +done + +echo "Prefill worker leaders: ${prefill_leaders[@]}" +echo "Decode worker leaders: ${decode_leaders[@]}" + +# Prepare enroot arguments to pass to srun commands +ENROOT_ARGS="\ + --container-image=${CONTAINER_IMAGE} \ + --no-container-entrypoint \ + --no-container-mount-home \ + --container-mounts=${MODEL_DIR}:/model/,${CONFIG_DIR}:/configs/,${SCRIPT_DIR}:/scripts/,${OUTPUT_DIR}:/outputs/,${LOG_DIR}:/logs/{% endraw %}{% if sglang_src_dir %},{{ sglang_src_dir }}:/ext-sglang-src/{% endif %}{% if extra_container_mounts %},{{ extra_container_mounts }}{% endif %}{% raw %} \ +" +{% endraw %} + +# Build common worker arguments +{% raw %} +WORKER_ARGS="--gpu_type ${GPU_TYPE} --gpus_per_node ${GPUS_PER_NODE} --master_ip ${MASTER_IP}" +{% endraw %} +{% if use_sglang_router %} +{% raw %} +WORKER_ARGS="$WORKER_ARGS --use-sglang-router" +{% endraw %} +{% endif %} +{% if enable_multiple_frontends and not use_sglang_router %} +{% raw %} +# Add multiple frontends flag for worker setup +WORKER_ARGS="$WORKER_ARGS --multiple-frontends-enabled" +{% endraw %} +{% endif %} +{% raw %} +# Set profiler mode from config +WORKER_ARGS="$WORKER_ARGS --profiler {% endraw %}{{ profiler }}{% raw %}" +{% endraw %} +{% raw %} +# Add SGLang config path (mounted in container at /logs/) +WORKER_ARGS="$WORKER_ARGS --sglang-config-path /logs/sglang_config.yaml" +{% endraw %} +{% if setup_script %} +# Add custom setup script if provided +WORKER_ARGS="$WORKER_ARGS --setup-script {{ setup_script }}" +{% endif %} +{% raw %} + +{% endraw %} +{% if enable_multiple_frontends and not use_sglang_router %} +{% raw %} +{% endraw %} +{% if 
total_nodes > 1 %} +{% raw %} +# Launch nginx on node 0 +echo "Launching nginx on ${NGINX_NODE}" +cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$NGINX_NODE --output=${LOG_DIR}/${NGINX_NODE}_nginx.out python /scripts/worker_setup.py --worker_type nginx --nginx_config /logs/nginx.conf ${WORKER_ARGS}" +echo "$cmd" +$cmd & +{% endraw %} +{% endif %} +{% raw %} + +# Launch frontend on master node (node 1) - this will also start NATS/ETCD +echo "Launching frontend + NATS/ETCD on master node ${MASTER_NODE}" +cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$MASTER_NODE --output=${LOG_DIR}/${MASTER_NODE}_frontend_0.out python /scripts/worker_setup.py --worker_type frontend --worker_idx 0 ${WORKER_ARGS}" +echo "$cmd" +$cmd & + +# Launch additional frontends on designated nodes +if [ "$ADDITIONAL_FRONTENDS" -gt 0 ]; then + frontend_idx=1 # Start from 1 since node 1 is frontend 0 + nodes_per_frontend=$(( (TOTAL_NODES - 2 + ADDITIONAL_FRONTENDS - 1) / ADDITIONAL_FRONTENDS )) + frontend_node_idx=2 + + for i in $(seq 1 $ADDITIONAL_FRONTENDS); do + if [ $frontend_node_idx -lt $TOTAL_NODES ]; then + node=${nodes[$frontend_node_idx]} + echo "Launching additional frontend $frontend_idx on node $frontend_node_idx: $node" + cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_frontend_${frontend_idx}.out python /scripts/worker_setup.py --worker_type frontend --worker_idx ${frontend_idx} ${WORKER_ARGS}" + echo "$cmd" + $cmd & + frontend_idx=$((frontend_idx + 1)) + frontend_node_idx=$((frontend_node_idx + nodes_per_frontend)) + fi + done +fi +{% endraw %} +{% endif %} +{% raw %} + +# Launch prefill workers +for worker_idx in $(seq 0 $((PREFILL_WORKERS - 1))); do + leader_idx=${prefill_leaders[$worker_idx]} + leader_node=${nodes[$leader_idx]} + + # Get leader IP for this worker group + LEADER_IP=$(get_node_ip "$leader_node" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + echo "Prefill worker $worker_idx leader: 
$leader_node ($LEADER_IP)" + + # Launch all nodes for this worker + for node_idx in $(seq 0 $((PREFILL_NODES_PER_WORKER - 1))); do + global_node_idx=$((leader_idx + node_idx)) + node=${nodes[$global_node_idx]} + local_rank=$node_idx + + echo "Launching prefill worker $worker_idx, node $global_node_idx (local_rank $local_rank): $node" +{% endraw %} +{% if enable_config_dump %} +{% raw %} + CONFIG_DUMP_ARG="--dump-config-path /logs/${node}_config.json" +{% endraw %} +{% else %} +{% raw %} + CONFIG_DUMP_ARG="" +{% endraw %} +{% endif %} +{% raw %} + cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_prefill_w${worker_idx}.out python /scripts/worker_setup.py --leader_ip ${LEADER_IP} --worker_idx ${worker_idx} --local_rank ${local_rank} --nodes_per_worker ${PREFILL_NODES_PER_WORKER} --worker_type prefill ${WORKER_ARGS} ${CONFIG_DUMP_ARG}" + echo "$cmd" + $cmd & + done +done + +# Launch decode workers +for worker_idx in $(seq 0 $((DECODE_WORKERS - 1))); do + leader_idx=${decode_leaders[$worker_idx]} + leader_node=${nodes[$leader_idx]} + + # Get leader IP for this worker group + LEADER_IP=$(get_node_ip "$leader_node" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + echo "Decode worker $worker_idx leader: $leader_node ($LEADER_IP)" + + # Launch all nodes for this worker + for node_idx in $(seq 0 $((DECODE_NODES_PER_WORKER - 1))); do + global_node_idx=$((leader_idx + node_idx)) + node=${nodes[$global_node_idx]} + local_rank=$node_idx + + echo "Launching decode worker $worker_idx, node $global_node_idx (local_rank $local_rank): $node" +{% endraw %} +{% if enable_config_dump %} +{% raw %} + CONFIG_DUMP_ARG="--dump-config-path /logs/${node}_config.json" +{% endraw %} +{% else %} +{% raw %} + CONFIG_DUMP_ARG="" +{% endraw %} +{% endif %} +{% raw %} + cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_decode_w${worker_idx}.out python /scripts/worker_setup.py --leader_ip ${LEADER_IP} 
--worker_idx ${worker_idx} --local_rank ${local_rank} --nodes_per_worker ${DECODE_NODES_PER_WORKER} --worker_type decode ${CONFIG_DUMP_ARG} ${WORKER_ARGS}" + echo "$cmd" + $cmd & + done +done + +echo "" +{% endraw %} +{% if enable_multiple_frontends and not use_sglang_router %} +{% raw %} +echo "Frontend available at: http://${NGINX_NODE}:8000" +echo "To connect to the nginx node:" +echo "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${NGINX_NODE} --overlap --pty bash" +echo "To connect to the master node (NATS/ETCD):" +echo "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${MASTER_NODE} --overlap --pty bash" +{% endraw %} +{% else %} +{% raw %} +echo "To connect to the host prefill node:" +echo "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --overlap --pty bash" +{% endraw %} +{% endif %} +{% raw %} + +# Launch sglang router(s) when enabled +{% endraw %}{% if use_sglang_router %}{% raw %} +# Wait for prefill server to be ready before launching router +echo "Waiting for prefill server to be ready on ${nodes[${prefill_leaders[0]}]}..." +PREFILL_LEADER_NODE=${nodes[${prefill_leaders[0]}]} +PREFILL_LEADER_IP=$(get_node_ip "$PREFILL_LEADER_NODE" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") +MAX_WAIT=600 +WAIT_COUNT=0 +while [ $WAIT_COUNT -lt $MAX_WAIT ]; do + status_code=$(srun --nodes=1 --ntasks=1 $ENROOT_ARGS --jobid $SLURM_JOB_ID -w $PREFILL_LEADER_NODE --overlap bash -c "curl -s -o /dev/null -w '%{http_code}' http://${PREFILL_LEADER_IP}:30000/health" 2>/dev/null || echo "000") + if [ "$status_code" -eq 200 ]; then + echo "Prefill server is ready at http://${PREFILL_LEADER_IP}:30000" + break + fi + echo "Prefill server not ready yet (status: ${status_code}), waiting... ($WAIT_COUNT/$MAX_WAIT)" + sleep 30 + WAIT_COUNT=$((WAIT_COUNT + 30)) +done + +if [ $WAIT_COUNT -ge $MAX_WAIT ]; then + echo "Warning: Prefill server did not become ready within ${MAX_WAIT}s, proceeding anyway..." 
+fi +# Collect leader IPs for prefill and decode +PREFILL_LEADER_IPS=() +for idx in "${prefill_leaders[@]}"; do + node_name=${nodes[$idx]} + ip=$(get_node_ip "$node_name" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + PREFILL_LEADER_IPS+=("$ip") +done +DECODE_LEADER_IPS=() +for idx in "${decode_leaders[@]}"; do + node_name=${nodes[$idx]} + ip=$(get_node_ip "$node_name" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + DECODE_LEADER_IPS+=("$ip") +done + +PREFILL_IPS_STR=$(IFS=,; echo "${PREFILL_LEADER_IPS[*]}") +DECODE_IPS_STR=$(IFS=,; echo "${DECODE_LEADER_IPS[*]}") + +{% endraw %} +{% if enable_multiple_frontends %} +{% raw %} +# Multiple router architecture (mirrors dynamo frontend scaling) +# Node 0: nginx load balancer + first router +# Node 1+: additional routers distributed across worker nodes + +NGINX_NODE=${nodes[0]} +NGINX_IP=$(get_node_ip "$NGINX_NODE" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + +# Build router host/IP lists +router_hosts=() +router_ips=() + +# First router always on node 0 +router_hosts+=("$NGINX_NODE") +router_ips+=("$NGINX_IP") + +# Add additional routers (uses same num_additional_frontends setting as dynamo) +{% endraw %}ADDITIONAL_ROUTERS={{ num_additional_frontends }}{% raw %} +if [ "$ADDITIONAL_ROUTERS" -gt 0 ]; then + # Calculate which nodes get additional routers + # Distribute additional routers across nodes, starting from node 1 + nodes_per_router=$(( (TOTAL_NODES - 1 + ADDITIONAL_ROUTERS - 1) / ADDITIONAL_ROUTERS )) # ceil division + router_node_idx=1 # Start from node 1 (node 0 already has first router) + + for i in $(seq 1 $ADDITIONAL_ROUTERS); do + if [ $router_node_idx -lt $TOTAL_NODES ]; then + node_name=${nodes[$router_node_idx]} + node_ip=$(get_node_ip "$node_name" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + router_hosts+=("$node_name") + router_ips+=("$node_ip") + echo "Additional router $i on node $router_node_idx: $node_name ($node_ip)" + router_node_idx=$((router_node_idx + nodes_per_router)) + fi + done +fi + +echo "Router hosts: 
${router_hosts[@]}" +echo "Router IPs: ${router_ips[@]}" + +# Generate nginx configuration for router load balancing +# Routers use internal port 30080, nginx exposes on 8000 +ROUTER_INTERNAL_PORT=30080 +ROUTER_LIST=$(printf "'%s'," "${router_ips[@]}") +ROUTER_LIST="[${ROUTER_LIST%,}]" +export ROUTER_LIST ROUTER_INTERNAL_PORT SCRIPT_DIR LOG_DIR +python3 - <<'PY' +import os +from jinja2 import Template + +template_path = os.path.join(os.environ['SCRIPT_DIR'], 'templates/nginx.conf.j2') +output_path = os.path.join(os.environ['LOG_DIR'], 'nginx.conf') + +with open(template_path, 'r') as f: + tmpl = Template(f.read()) + +router_hosts = eval(os.environ['ROUTER_LIST']) +backend_port = int(os.environ['ROUTER_INTERNAL_PORT']) +config = tmpl.render(frontend_hosts=router_hosts, backend_port=backend_port) + +with open(output_path, 'w') as f: + f.write(config) +PY + +# Launch nginx on node 0 +echo "Launching nginx for router load balancing on ${NGINX_NODE}" +cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$NGINX_NODE --output=${LOG_DIR}/${NGINX_NODE}_nginx.out python /scripts/worker_setup.py --worker_type nginx --nginx_config /logs/nginx.conf ${WORKER_ARGS}" +echo "$cmd" +$cmd & + +# Launch first router on node 0 (with nginx) +# Router listens on internal port, nginx proxies from 8000 +echo "Launching sglang-router 0 on ${NGINX_NODE} (internal port ${ROUTER_INTERNAL_PORT})" +cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$NGINX_NODE --output=${LOG_DIR}/${NGINX_NODE}_router_0.out python /scripts/worker_setup.py --worker_type sglang-router --worker_idx 0 --prefill-ips ${PREFILL_IPS_STR} --decode-ips ${DECODE_IPS_STR} --router-port ${ROUTER_INTERNAL_PORT} ${WORKER_ARGS}" +echo "$cmd" +$cmd & + +# Launch additional routers on designated nodes +if [ "$ADDITIONAL_ROUTERS" -gt 0 ]; then + router_idx=1 + nodes_per_router=$(( (TOTAL_NODES - 1 + ADDITIONAL_ROUTERS - 1) / ADDITIONAL_ROUTERS )) + router_node_idx=1 + + for i in $(seq 1 
$ADDITIONAL_ROUTERS); do + if [ $router_node_idx -lt $TOTAL_NODES ]; then + node=${nodes[$router_node_idx]} + echo "Launching sglang-router $router_idx on node $router_node_idx: $node (internal port ${ROUTER_INTERNAL_PORT})" + cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_router_${router_idx}.out python /scripts/worker_setup.py --worker_type sglang-router --worker_idx ${router_idx} --prefill-ips ${PREFILL_IPS_STR} --decode-ips ${DECODE_IPS_STR} --router-port ${ROUTER_INTERNAL_PORT} ${WORKER_ARGS}" + echo "$cmd" + $cmd & + router_idx=$((router_idx + 1)) + router_node_idx=$((router_node_idx + nodes_per_router)) + fi + done +fi + +TOTAL_ROUTERS=$((1 + ADDITIONAL_ROUTERS)) +echo "Frontend available at: http://${NGINX_NODE}:8000 (nginx load balancing ${TOTAL_ROUTERS} sglang-routers)" +{% endraw %} +{% else %} +{% raw %} +# Single router architecture - no nginx, router directly on node 0 port 8000 +ROUTER_NODE=${nodes[0]} +echo "Launching single sglang-router on ${ROUTER_NODE} (port 8000)" +cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$ROUTER_NODE --output=${LOG_DIR}/${ROUTER_NODE}_router.out python /scripts/worker_setup.py --worker_type sglang-router --worker_idx 0 --prefill-ips ${PREFILL_IPS_STR} --decode-ips ${DECODE_IPS_STR} --router-port 8000 ${WORKER_ARGS}" +echo "$cmd" +$cmd & + +echo "Frontend available at: http://${ROUTER_NODE}:8000" +{% endraw %} +{% endif %} +{% endif %} +{% raw %} + +echo "" +echo "Make sure to cancel the job at the end:" +echo "scancel $SLURM_JOB_ID" + +# Instead of waiting for all tasks to complete, wait for benchmark to complete and then exit. 
+ +{% endraw %} + +BENCHMARK_TYPE={{ benchmark_type }} +BENCHMARK_ARGS="{{ benchmark_arg }}" +USE_SGLANG_ROUTER={{ "true" if use_sglang_router else "false" }} + +{% if do_benchmark %} +{% raw %} +srun --nodes=1 --ntasks=1 $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --output=${LOG_DIR}/benchmark.out --overlap bash /scripts/benchmarks/${BENCHMARK_TYPE}/bench.sh $PREFILL_WORKERS $DECODE_WORKERS $PREFILL_GPUS $DECODE_GPUS ${BENCHMARK_ARGS} ${USE_SGLANG_ROUTER} & +{% endraw %} +{% endif %} + +{% if profiler != 'none' %} +{% raw %} +# Torch/NSYS profiling mode: run a single orchestrator that profiles all prefill and decode workers. +echo "Starting unified profiler..." + +# Collect leader IPs for prefill and decode workers +PREFILL_LEADER_IPS=() +for idx in "${prefill_leaders[@]}"; do + node_name=${nodes[$idx]} + ip=$(get_node_ip "$node_name" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + PREFILL_LEADER_IPS+=("$ip") +done + +DECODE_LEADER_IPS=() +for idx in "${decode_leaders[@]}"; do + node_name=${nodes[$idx]} + ip=$(get_node_ip "$node_name" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + DECODE_LEADER_IPS+=("$ip") +done + +PREFILL_LEADER_IPS_STR=$(IFS=,; echo "${PREFILL_LEADER_IPS[*]}") +DECODE_LEADER_IPS_STR=$(IFS=,; echo "${DECODE_LEADER_IPS[*]}") + +# Use the first prefill leader as the orchestrator node +PROFILE_ORCHESTRATOR_NODE=${nodes[${prefill_leaders[0]}]} +echo "Unified profiling will run on orchestrator node: $PROFILE_ORCHESTRATOR_NODE" + +# Run a single profiling orchestrator that coordinates profiling across all leaders +srun --nodes=1 --ntasks=1 $ENROOT_ARGS --jobid $SLURM_JOB_ID -w $PROFILE_ORCHESTRATOR_NODE \ + --output=${LOG_DIR}/profile_all.out --overlap \ + bash -c "PROFILE_PREFILL_IPS=${PREFILL_LEADER_IPS_STR} PROFILE_DECODE_IPS=${DECODE_LEADER_IPS_STR} {% endraw %}{% if profiler == 'torch' %}SGLANG_TORCH_PROFILER_DIR=/logs/profiles {% endif %}{{ profiling_driver_env }} {{ prefill_profile_env }} {{ decode_profile_env }}{% raw %} /scripts/profiling/profile.sh 
$PREFILL_WORKERS $DECODE_WORKERS $PREFILL_GPUS $DECODE_GPUS $TOTAL_GPUS" & +{% endraw %} +{% endif %} + +{% if profiler != 'none' %} +{% raw %} +# Wait for all profiling scripts to complete (both prefill and decode) +echo "Waiting for all profiling scripts to complete..." +wait +exit_code=$? +echo "All profiling scripts finished at $(date) with exit code ${exit_code}" +exit $exit_code +{% endraw %} +{% else %} +{% raw %} +# Wait for first task (benchmark) to complete +wait -n +first_exit_code=$? +echo "Script finished at $(date) with exit code ${first_exit_code}" +exit $first_exit_code +{% endraw %} +{% endif %} diff --git a/scripts/worker_setup/command.py b/scripts/worker_setup/command.py new file mode 100644 index 00000000..71e4ab6f --- /dev/null +++ b/scripts/worker_setup/command.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Command building functions for SGLang workers.""" + +import logging +import os +import subprocess + + +def build_sglang_command_from_yaml( + worker_type: str, + worker_idx: int, + sglang_config_path: str, + host_ip: str, + port: int, + total_nodes: int, + rank: int, + profiler: str = "none", + dump_config_path: str | None = None, + use_sglang_router: bool = False, +) -> str: + """Build SGLang command using native YAML config support. + + dynamo.sglang supports reading config from YAML: + python3 -m dynamo.sglang --config file.yaml --config-key prefill + + sglang.launch_server (profiling mode or sglang router mode) requires explicit flags: + python3 -m sglang.launch_server --model-path /model/ --tp 4 ... 
+ + Args: + worker_type: "prefill", "decode", or "aggregated" + sglang_config_path: Path to generated sglang_config.yaml + host_ip: Host IP for distributed coordination + port: Port for distributed coordination + total_nodes: Total number of nodes + rank: Node rank (0-indexed) + profiler: Profiling method: "none", "torch", or "nsys" + use_sglang_router: Use sglang.launch_server instead of dynamo.sglang + + Returns: + Full command string ready to execute + """ + import yaml + + # Load config to extract environment variables and mode config + with open(sglang_config_path) as f: + sglang_config = yaml.safe_load(f) + + config_key = worker_type if worker_type != "aggregated" else "aggregated" + + # Environment variables are stored at top level as {mode}_environment + env_key = f"{config_key}_environment" + env_vars = sglang_config.get(env_key, {}) + + # Build environment variable exports + env_exports = [] + for key, value in env_vars.items(): + env_exports.append(f"export {key}={value}") + if profiler == "torch": + env_exports.append(f"export SGLANG_TORCH_PROFILER_DIR=/logs/profiles/{config_key}") + + # Determine Python module based on profiling mode or sglang router mode + # Use sglang.launch_server when profiling OR when using sglang router (no dynamo) + use_launch_server = profiler != "none" or use_sglang_router + python_module = "sglang.launch_server" if use_launch_server else "dynamo.sglang" + nsys_prefix = f"nsys profile -t cuda,nvtx --cuda-graph-trace=node -c cudaProfilerApi --capture-range-end stop --force-overwrite true -o /logs/profiles/{config_key}_w{worker_idx}_{rank}" + + if use_launch_server: + # Profiling mode: inline all flags (sglang.launch_server doesn't support --config) + mode_config = sglang_config.get(config_key, {}) + # Wrap with NSYS on all ranks; outputs are isolated per-rank + if profiler == "nsys": + cmd_parts = [f"{nsys_prefix} python3 -m {python_module}"] + else: + cmd_parts = [f"python3 -m {python_module}"] + + # Add all SGLang flags from 
config + for key, value in sorted(mode_config.items()): + flag_name = key.replace("_", "-") + if isinstance(value, bool): + if value: + cmd_parts.append(f"--{flag_name}") + elif isinstance(value, list): + values_str = " ".join(str(v) for v in value) + cmd_parts.append(f"--{flag_name} {values_str}") + else: + cmd_parts.append(f"--{flag_name} {value}") + + # Add coordination flags + cmd_parts.extend( + [ + f"--dist-init-addr {host_ip}:{port}", + f"--nnodes {total_nodes}", + f"--node-rank {rank}", + "--host 0.0.0.0", + ] + ) + else: + # Normal mode: use --config and --config-key (dynamo.sglang supports this) + cmd_parts = [ + f"python3 -m {python_module}", + f"--config {sglang_config_path}", + f"--config-key {config_key}", + f"--dist-init-addr {host_ip}:{port}", + f"--nnodes {total_nodes}", + f"--node-rank {rank}", + "--host 0.0.0.0", + ] + + # Add dump-config-to flag if provided (not supported by sglang.launch_server; not used in aggregated mode) + if dump_config_path and not use_launch_server and worker_type != "aggregated": + cmd_parts.append(f"--dump-config-to {dump_config_path}") + + # Combine environment exports and command + full_command = " && ".join(env_exports + [" ".join(cmd_parts)]) if env_exports else " ".join(cmd_parts) + + return full_command + + +def install_dynamo_wheels(gpu_type: str) -> None: + """Install dynamo from PyPI. 
+ + Args: + gpu_type: GPU type (unused - pip auto-selects correct architecture) + """ + logging.info("Installing dynamo 0.7.0 from PyPI") + + # Install ai-dynamo-runtime (pip auto-selects x86_64 or aarch64 wheel) + runtime_package = "ai-dynamo-runtime==0.7.0" + logging.info(f"Installing {runtime_package}") + result = subprocess.run(["python3", "-m", "pip", "install", runtime_package], capture_output=True, text=True) + if result.returncode != 0: + logging.error(f"Failed to install runtime package: {result.stderr}") + raise RuntimeError(f"Failed to install {runtime_package}") + + # Install ai-dynamo + dynamo_package = "ai-dynamo==0.7.0" + logging.info(f"Installing {dynamo_package}") + result = subprocess.run(["python3", "-m", "pip", "install", dynamo_package], capture_output=True, text=True) + if result.returncode != 0: + logging.error(f"Failed to install dynamo package: {result.stderr}") + raise RuntimeError(f"Failed to install {dynamo_package}") + + logging.info("Successfully installed dynamo from PyPI") + + +def install_sglang_from_source(sglang_src_path: str = "/ext-sglang-src") -> None: + """Install sglang from source for debugging in sglang-router mode. + + Skips installation silently if source directory is not mounted. 
+ + Args: + sglang_src_path: Path to sglang source directory (default: /ext-sglang-src) + """ + if not os.path.exists(sglang_src_path): + logging.info(f"SGLang source not mounted at {sglang_src_path}, skipping source installation") + return + + # Verify the path is absolute and exists + abs_path = os.path.abspath(sglang_src_path) + # SGLang's Python package is in the 'python' subdirectory + abs_path = os.path.join(abs_path, 'python') + logging.info(f"Installing sglang from source: {abs_path}") + logging.info(f"Directory exists: {os.path.isdir(abs_path)}") + + # List directory contents for debugging + try: + contents = os.listdir(abs_path) + logging.info(f"Directory contains {len(contents)} items: {contents[:20]}") + except Exception as e: + logging.error(f"Cannot list directory {abs_path}: {e}") + raise RuntimeError(f"Cannot access sglang source directory: {abs_path}") + + # Check if this looks like a valid Python project + setup_py = os.path.join(abs_path, "setup.py") + pyproject_toml = os.path.join(abs_path, "pyproject.toml") + + if not os.path.exists(setup_py) and not os.path.exists(pyproject_toml): + logging.error(f"Directory {abs_path} does not contain setup.py or pyproject.toml") + logging.error(f"This does not appear to be a valid sglang source directory") + + raise RuntimeError( + f"Invalid sglang source directory: {abs_path}\n" + f"The directory must contain setup.py or pyproject.toml.\n" + f"Check that sglang_src_dir in your YAML points to the sglang project root.\n" + f"Found files: {contents[:20]}" + ) + + logging.info(f"Found Python project files in {abs_path}") + + # Change to sglang source directory and install in editable mode + result = subprocess.run( + ["python3", "-m", "pip", "install", "-e", ".", "--no-deps"], + cwd=abs_path, + capture_output=True, + text=True, + ) + if result.returncode != 0: + logging.error(f"Failed to install sglang from source: {result.stderr}") + raise RuntimeError(f"Failed to install sglang from {abs_path}") + + 
logging.info("Successfully installed sglang from source") + + +def get_gpu_command( + worker_type: str, + worker_idx: int, + sglang_config_path: str, + host_ip: str, + port: int, + total_nodes: int, + rank: int, + profiler: str = "none", + dump_config_path: str | None = None, + use_sglang_router: bool = False, +) -> str: + """Generate command to run SGLang worker using YAML config. + + Args: + worker_type: "prefill", "decode", or "aggregated" + sglang_config_path: Path to sglang_config.yaml + host_ip: Host IP for distributed coordination + port: Port for distributed coordination + total_nodes: Total number of nodes + rank: Node rank (0-indexed) + profiler: Profiling method: "none", "torch", or "nsys" + use_sglang_router: Use sglang.launch_server instead of dynamo.sglang + + Returns: + Command string to execute + """ + if not sglang_config_path or not os.path.exists(sglang_config_path): + raise ValueError(f"SGLang config path required but not found: {sglang_config_path}") + + logging.info(f"Building command from YAML config: {sglang_config_path}") + return build_sglang_command_from_yaml( + worker_type, worker_idx, sglang_config_path, host_ip, port, total_nodes, rank, profiler, dump_config_path, use_sglang_router + ) diff --git a/scripts/worker_setup/worker.py b/scripts/worker_setup/worker.py new file mode 100644 index 00000000..cdb5d25d --- /dev/null +++ b/scripts/worker_setup/worker.py @@ -0,0 +1,316 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Worker setup functions for prefill, decode, and aggregated workers.""" + +import logging +import subprocess +import os + +from .command import ( + get_gpu_command, + install_dynamo_wheels, + install_sglang_from_source, +) +from .environment import DIST_INIT_PORT, ETCD_CLIENT_PORT +from .infrastructure import setup_head_prefill_node +from .utils import run_command, wait_for_etcd + + +def _get_sglang_version() -> str | None: + """Get the installed sglang version.""" + try: + result = subprocess.run( + ["python", "-c", "import sglang; print(sglang.__version__)"], capture_output=True, text=True + ) + if result.returncode == 0: + return result.stdout.strip() + except Exception as e: + logging.warning(f"Failed to get sglang version: {e}") + return None + + +# TODO: this can be removed once we sync GB200 to > 0.5.5.post2 +def _patch_sglang_engine(): + """Temporary patch to fix send_to_rpc initialization. Only applies to 0.5.5.post2.""" + version = _get_sglang_version() + if version != "0.5.5.post2": + logging.info(f"Skipping patch - sglang version {version} != 0.5.5.post2") + return + + logging.info("Applying temporary patch to engine.py (sglang 0.5.5.post2)") + sed_cmd = ( + "sed -i '/self.send_to_rpc = get_zmq_socket(/,/^ )/c\\" + " if self.server_args.node_rank == 0:\\n" + " self.send_to_rpc = get_zmq_socket(\\n" + " context, zmq.DEALER, self.port_args.rpc_ipc_name, True\\n" + " )\\n" + " else:\\n" + " self.send_to_rpc = None' " + "/sgl-workspace/sglang/python/sglang/srt/entrypoints/engine.py" + ) + result = subprocess.run(sed_cmd, shell=True, capture_output=True, text=True) + if result.returncode != 0: + logging.warning(f"Failed to apply patch: {result.stderr}") + else: + logging.info("Patch applied successfully") + + +def _run_setup_script(setup_script: str | None = None): + """ + Run a setup script in the /configs directory if explicitly provided. 
+ + Args: + setup_script: Custom setup script name (e.g., 'custom-setup.sh'). + If None, no setup script runs. + """ + if not setup_script: + return + + script_path = f"/configs/{setup_script}" + + if os.path.exists(script_path): + logging.info(f"Running setup script: {script_path}") + run_command(f"bash {script_path}") + else: + logging.warning(f"Setup script not found: {script_path}") + + +def setup_prefill_worker( + worker_idx: int, + local_rank: int, + leader_ip: str, + master_ip: str, + nodes_per_worker: int, + gpu_type: str, + multiple_frontends_enabled: bool = False, + profiler: str = "none", + sglang_config_path: str | None = None, + dump_config_path: str | None = None, + setup_script: str | None = None, + use_sglang_router: bool = False, +) -> int: + """Setup the prefill worker.""" + # Setup infrastructure first (if traditional mode) + need_frontend = not multiple_frontends_enabled and worker_idx == 0 and local_rank == 0 + + if not use_sglang_router: + if need_frontend: + setup_head_prefill_node(master_ip) + if not wait_for_etcd(f"http://{master_ip}:{ETCD_CLIENT_PORT}"): + raise RuntimeError("Failed to connect to etcd") + else: + logging.info(f"Setting up prefill worker {worker_idx}, local rank {local_rank}") + if not wait_for_etcd(f"http://{master_ip}:{ETCD_CLIENT_PORT}"): + raise RuntimeError("Failed to connect to etcd") + + # Install dynamo from PyPI (only needed when not using sglang router) + install_dynamo_wheels(gpu_type) + else: + # Install sglang from source when using sglang-router mode (for debugging) + logging.info(f"Setting up prefill worker {worker_idx}, local rank {local_rank} (sglang-router mode)") + install_sglang_from_source() + + # Run custom setup script if provided + _run_setup_script(setup_script) + + # Start frontend AFTER installing dynamo (traditional mode only, not when using sglang router) + if need_frontend and not use_sglang_router: + logging.info("Starting frontend in traditional mode (after dynamo installation)") + + # Open 
log files for frontend + frontend_stdout = open("/logs/frontend.out", "w") + frontend_stderr = open("/logs/frontend.err", "w") + + frontend_cmd = "python3 -m dynamo.frontend --http-port=8000" + frontend_process = run_command(frontend_cmd, background=True, stdout=frontend_stdout, stderr=frontend_stderr) + if not frontend_process: + raise RuntimeError("Failed to start frontend") + logging.info(f"Frontend started in background (PID: {frontend_process.pid})") + logging.info("Frontend logs: /logs/frontend.out and /logs/frontend.err") + + # Apply temporary patch (for gb200 (not gb300) and h100) + if (gpu_type.startswith("gb200") and not gpu_type.startswith("gb300")) or gpu_type.startswith("h100"): + _patch_sglang_engine() + + # Build and execute SGLang command from YAML config + cmd_to_run = get_gpu_command( + worker_type="prefill", + worker_idx=worker_idx, + sglang_config_path=sglang_config_path, + host_ip=leader_ip, + port=DIST_INIT_PORT, + total_nodes=nodes_per_worker, + rank=local_rank, + profiler=profiler, + dump_config_path=dump_config_path, + use_sglang_router=use_sglang_router, + ) + return run_command(cmd_to_run) + + +def setup_decode_worker( + worker_idx: int, + local_rank: int, + leader_ip: str, + master_ip: str, + nodes_per_worker: int, + gpu_type: str, + profiler: str = "none", + sglang_config_path: str | None = None, + dump_config_path: str | None = None, + setup_script: str | None = None, + use_sglang_router: bool = False, +) -> int: + """Setup the decode worker.""" + logging.info(f"Setting up decode worker {worker_idx}, local rank {local_rank}") + + if not use_sglang_router: + if not wait_for_etcd(f"http://{master_ip}:{ETCD_CLIENT_PORT}"): + raise RuntimeError("Failed to connect to etcd") + + # Install dynamo from PyPI (only needed when not using sglang router) + install_dynamo_wheels(gpu_type) + else: + # Install sglang from source when using sglang-router mode (for debugging) + install_sglang_from_source() + + # Run custom setup script if provided + 
_run_setup_script(setup_script) + + # Apply temporary patch (for gb200 (not gb300) and h100) + if (gpu_type.startswith("gb200") and not gpu_type.startswith("gb300")) or gpu_type.startswith("h100"): + _patch_sglang_engine() + + # Build and execute SGLang command from YAML config + cmd_to_run = get_gpu_command( + worker_type="decode", + worker_idx=worker_idx, + sglang_config_path=sglang_config_path, + host_ip=leader_ip, + port=DIST_INIT_PORT, + total_nodes=nodes_per_worker, + rank=local_rank, + profiler=profiler, + dump_config_path=dump_config_path, + use_sglang_router=use_sglang_router, + ) + return run_command(cmd_to_run) + + +def setup_router_worker( + router_idx: int, + prefill_ips: list[str], + decode_ips: list[str], + host: str = "0.0.0.0", + port: int = 8000, + server_port: int = 30000, + bootstrap_port: int = 30001, +) -> int: + """Setup an sglang router worker for PD disaggregation. + + Args: + router_idx: Index of this router instance (for logging) + prefill_ips: List of prefill worker leader IPs + decode_ips: List of decode worker leader IPs + host: Host to bind the router to + port: Port to bind the router to + server_port: Port where prefill/decode servers listen (default: 30000) + bootstrap_port: Disaggregation bootstrap port for prefill servers (default: 30001) + + Returns: + Exit code from the router process + """ + logging.info(f"Setting up sglang router {router_idx}") + logging.info(f" Prefill IPs: {prefill_ips}") + logging.info(f" Decode IPs: {decode_ips}") + logging.info(f" Server port: {server_port}, Bootstrap port: {bootstrap_port}") + + # Build router command + router_args = ["python", "-m", "sglang_router.launch_router", "--pd-disaggregation"] + + # Prefill servers need: --prefill http://IP:server_port bootstrap_port + for ip in prefill_ips: + router_args.extend(["--prefill", f"http://{ip}:{server_port}", str(bootstrap_port)]) + + # Decode servers just need: --decode http://IP:server_port + for ip in decode_ips: + 
router_args.extend(["--decode", f"http://{ip}:{server_port}"]) + + router_args.extend(["--host", host, "--port", str(port)]) + + cmd = " ".join(router_args) + logging.info(f"Router command: {cmd}") + return run_command(cmd) + + +def setup_aggregated_worker( + worker_idx: int, + local_rank: int, + leader_ip: str, + master_ip: str, + nodes_per_worker: int, + gpu_type: str, + multiple_frontends_enabled: bool = False, + profiler: str = "none", + sglang_config_path: str | None = None, + dump_config_path: str | None = None, + setup_script: str | None = None, + use_sglang_router: bool = False, +) -> int: + """Setup the aggregated worker.""" + # Setup infrastructure first (if traditional mode) + need_frontend = not multiple_frontends_enabled and worker_idx == 0 and local_rank == 0 + + if not use_sglang_router: + if need_frontend: + setup_head_prefill_node(master_ip) + if not wait_for_etcd(f"http://{master_ip}:{ETCD_CLIENT_PORT}"): + raise RuntimeError("Failed to connect to etcd") + else: + logging.info(f"Setting up aggregated worker {worker_idx}, local rank {local_rank}") + if not wait_for_etcd(f"http://{master_ip}:{ETCD_CLIENT_PORT}"): + raise RuntimeError("Failed to connect to etcd") + + # Install dynamo from PyPI (only needed when not using sglang router) + install_dynamo_wheels(gpu_type) + else: + # Install sglang from source when using sglang-router mode (for debugging) + logging.info(f"Setting up aggregated worker {worker_idx}, local rank {local_rank} (sglang-router mode)") + install_sglang_from_source() + + # Run custom setup script if provided + _run_setup_script(setup_script) + + # Start frontend AFTER installing dynamo (traditional mode only, not when using sglang router) + if need_frontend and not use_sglang_router: + logging.info("Starting frontend in traditional mode (after dynamo installation)") + + # Open log files for frontend + frontend_stdout = open("/logs/frontend.out", "w") + frontend_stderr = open("/logs/frontend.err", "w") + + frontend_cmd = "python3 
-m dynamo.frontend --http-port=8000" + frontend_process = run_command(frontend_cmd, background=True, stdout=frontend_stdout, stderr=frontend_stderr) + if not frontend_process: + raise RuntimeError("Failed to start frontend") + logging.info(f"Frontend started in background (PID: {frontend_process.pid})") + logging.info("Frontend logs: /logs/frontend.out and /logs/frontend.err") + + # Apply temporary patch (for gb200 (not gb300) and h100) + if (gpu_type.startswith("gb200") and not gpu_type.startswith("gb300")) or gpu_type.startswith("h100"): + _patch_sglang_engine() + + # Build and execute SGLang command from YAML config + cmd_to_run = get_gpu_command( + worker_type="aggregated", + sglang_config_path=sglang_config_path, + host_ip=leader_ip, + port=DIST_INIT_PORT, + total_nodes=nodes_per_worker, + rank=local_rank, + profiler=profiler, + dump_config_path=dump_config_path, + use_sglang_router=use_sglang_router, + ) + return run_command(cmd_to_run) diff --git a/src/srtctl/core/backend.py b/src/srtctl/core/backend.py new file mode 100644 index 00000000..cc3016cc --- /dev/null +++ b/src/srtctl/core/backend.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""SGLang backend for SLURM job generation.""" + +import logging +import os +import tempfile +import yaml +from datetime import datetime +from jinja2 import Template +from pathlib import Path + +import srtctl +from srtctl.core.config import get_srtslurm_setting +from srtctl.core.sweep import expand_template + + +class SGLangBackend: + """SGLang backend for distributed serving.""" + + def __init__(self, config: dict, setup_script: str = None): + self.config = config + self.backend_config = config.get("backend", {}) + self.resources = config.get("resources", {}) + self.model = config.get("model", {}) + self.slurm = config.get("slurm", {}) + self.setup_script = setup_script + + def is_disaggregated(self) -> bool: + return self.resources.get("prefill_nodes") is not None + + def get_environment_vars(self, mode: str) -> dict[str, str]: + return self.backend_config.get(f"{mode}_environment", {}) + + def _profiling_type(self) -> str: + return (self.config.get("profiling") or {}).get("type") or "none" + + def _get_enable_config_dump(self) -> bool: + value = self.config.get("enable_config_dump") + if value is not None: + return bool(value) + return self._profiling_type() == "none" + + @staticmethod + def _build_phase_steps_env_str(phase: str, defaults: dict, overrides: dict | None) -> str: + merged = dict(defaults) + if overrides: + merged.update({k: v for k, v in overrides.items() if v is not None}) + + parts: list[str] = [] + if merged.get("start_step") is not None: + parts.append(f"PROFILE_{phase}_START_STEP={merged['start_step']}") + if merged.get("stop_step") is not None: + parts.append(f"PROFILE_{phase}_STOP_STEP={merged['stop_step']}") + return " ".join(parts) + + @staticmethod + def _build_driver_env_str(cfg: dict) -> str: + parts: list[str] = [] + if cfg.get("isl") is not None: + parts.append(f"PROFILE_ISL={cfg['isl']}") + if cfg.get("osl") is not None: + parts.append(f"PROFILE_OSL={cfg['osl']}") + if cfg.get("concurrency") 
is not None: + parts.append(f"PROFILE_CONCURRENCY={cfg['concurrency']}") + return " ".join(parts) + + def _config_to_flags(self, config: dict) -> list[str]: + lines = [] + for key, value in sorted(config.items()): + flag = key.replace("_", "-") + if isinstance(value, bool): + if value: + lines.append(f" --{flag} \\") + elif isinstance(value, list): + lines.append(f" --{flag} {' '.join(str(v) for v in value)} \\") + else: + lines.append(f" --{flag} {value} \\") + return lines + + def generate_config_file(self, params: dict = None) -> Path | None: + """Generate SGLang YAML config file.""" + if "sglang_config" not in self.backend_config: + return None + + sglang_cfg = self.backend_config["sglang_config"] + if params: + sglang_cfg = expand_template(sglang_cfg, params) + logging.info(f"Expanded config with params: {params}") + + # Validate kebab-case keys + for mode in ["prefill", "decode", "aggregated"]: + if mode in sglang_cfg and sglang_cfg[mode]: + for key in sglang_cfg[mode].keys(): + if "_" in key: + raise ValueError(f"Invalid key '{key}': use '{key.replace('_', '-')}' (kebab-case)") + + result = {mode: sglang_cfg[mode] for mode in ["prefill", "decode", "aggregated"] if mode in sglang_cfg} + for mode in ["prefill", "decode", "aggregated"]: + if env := self.get_environment_vars(mode): + result[f"{mode}_environment"] = env + + fd, temp_path = tempfile.mkstemp(suffix=".yaml", prefix="sglang_config_") + with os.fdopen(fd, "w") as f: + yaml.dump(result, f, default_flow_style=False) + logging.info(f"Generated SGLang config: {temp_path}") + return Path(temp_path) + + def render_command(self, mode: str, config_path: Path = None) -> str: + """Render full SGLang command with all flags inlined.""" + lines = [f"{k}={v} \\" for k, v in (self.get_environment_vars(mode) or {}).items()] + + prof = self._profiling_type() + use_sglang = prof != "none" or self.backend_config.get("use_sglang_router", False) + if prof == "nsys": + lines.append( + "nsys profile -t cuda,nvtx 
--cuda-graph-trace=node -c cudaProfilerApi --capture-range-end stop --force-overwrite true python3 -m sglang.launch_server \\" + ) + elif use_sglang: + lines.append("python3 -m sglang.launch_server \\") + else: + lines.append("python3 -m dynamo.sglang \\") + + if config_path: + with open(config_path) as f: + sglang_config = yaml.safe_load(f) + lines.extend(self._config_to_flags(sglang_config.get(mode, {}))) + + nnodes = ( + (self.resources["prefill_nodes"] if mode == "prefill" else self.resources["decode_nodes"]) + if self.is_disaggregated() + else self.resources["agg_nodes"] + ) + lines.extend( + [ + " --dist-init-addr $HOST_IP_MACHINE:$PORT \\", + f" --nnodes {nnodes} \\", + " --node-rank $RANK \\", + ] + ) + return "\n".join(lines) + + def generate_slurm_script(self, config_path: Path = None, timestamp: str = None) -> tuple[Path, str]: + """Generate SLURM job script from Jinja template.""" + timestamp = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S") + is_aggregated = not self.is_disaggregated() + + if is_aggregated: + agg_nodes, agg_workers = self.resources["agg_nodes"], self.resources["agg_workers"] + prefill_nodes = decode_nodes = prefill_workers = decode_workers = 0 + total_nodes = agg_nodes + else: + prefill_nodes, decode_nodes = self.resources["prefill_nodes"], self.resources["decode_nodes"] + prefill_workers, decode_workers = self.resources["prefill_workers"], self.resources["decode_workers"] + agg_nodes = agg_workers = 0 + total_nodes = prefill_nodes + decode_nodes + + # SLURM settings + job_name = self.config.get("name", "srtctl-job") + account = self.slurm.get("account") or get_srtslurm_setting("default_account") + partition = self.slurm.get("partition") or get_srtslurm_setting("default_partition") + time_limit = self.slurm.get("time_limit") or get_srtslurm_setting("default_time_limit", "04:00:00") + gpus_per_node = get_srtslurm_setting("gpus_per_node", self.resources.get("gpus_per_node")) + network_interface = 
get_srtslurm_setting("network_interface", None) + gpu_type = self.resources.get("gpu_type", "h100") + + # Benchmark config + benchmark_config = self.config.get("benchmark", {}) + bench_type = benchmark_config.get("type", "manual") + parsable_config = "" + if bench_type == "sa-bench": + conc = benchmark_config.get("concurrencies") + conc_str = "x".join(str(c) for c in conc) if isinstance(conc, list) else str(conc) + parsable_config = f"{benchmark_config.get('isl')} {benchmark_config.get('osl')} {conc_str} {benchmark_config.get('req_rate', 'inf')}" + elif bench_type == "mmlu": + num_examples = benchmark_config.get("num_examples", 200) + max_tokens = benchmark_config.get("max_tokens", 8192) + repeat = benchmark_config.get("repeat", 8) + num_threads = benchmark_config.get("num_threads", 512) + parsable_config = f"{num_examples} {max_tokens} {repeat} {num_threads}" + elif bench_type == "gpqa": + num_examples = benchmark_config.get("num_examples", 198) + max_tokens = benchmark_config.get("max_tokens", 32768) + repeat = benchmark_config.get("repeat", 8) + num_threads = benchmark_config.get("num_threads", 128) + parsable_config = f"{num_examples} {max_tokens} {repeat} {num_threads}" + elif bench_type == "longbenchv2": + num_examples = benchmark_config.get("num_examples", None) + max_tokens = benchmark_config.get("max_tokens", 16384) + max_context_length = benchmark_config.get("max_context_length", 128000) + num_threads = benchmark_config.get("num_threads", 16) + categories = benchmark_config.get("categories", None) + parsable_config = f"{num_examples} {max_tokens} {max_context_length} {num_threads} {categories}" + + # Paths + srtctl_root = Path(get_srtslurm_setting("srtctl_root") or Path(srtctl.__file__).parent.parent.parent) + config_dir_path = srtctl_root / "configs" + log_dir_path = srtctl_root / "logs" + + profiling_cfg = self.config.get("profiling") or {} + profiling_defaults: dict = {} + + prefill_profile_env = self._build_phase_steps_env_str("PREFILL", 
profiling_defaults, profiling_cfg.get("prefill")) + decode_profile_env = self._build_phase_steps_env_str("DECODE", profiling_defaults, profiling_cfg.get("decode")) + aggregated_profile_env = self._build_phase_steps_env_str("AGG", profiling_defaults, profiling_cfg.get("aggregated")) + + profiling_driver_env = self._build_driver_env_str(profiling_cfg) + profiler_mode = profiling_cfg.get("type") or "none" + + template_vars = { + "job_name": job_name, + "total_nodes": total_nodes, + "account": account, + "time_limit": time_limit, + "prefill_nodes": prefill_nodes, + "decode_nodes": decode_nodes, + "prefill_workers": prefill_workers, + "decode_workers": decode_workers, + "agg_nodes": agg_nodes, + "agg_workers": agg_workers, + "is_aggregated": is_aggregated, + "model_dir": self.model.get("path"), + "config_dir": str(config_dir_path), + "container_image": self.model.get("container"), + "gpus_per_node": gpus_per_node, + "network_interface": network_interface, + "gpu_type": gpu_type, + "partition": partition, + "enable_multiple_frontends": self.backend_config.get("enable_multiple_frontends", True), + "num_additional_frontends": self.backend_config.get("num_additional_frontends", 9), + "use_sglang_router": self.backend_config.get("use_sglang_router", False), + "sglang_src_dir": self.backend_config.get("sglang_src_dir"), + "do_benchmark": bench_type != "manual", + "benchmark_type": bench_type, + "benchmark_arg": parsable_config, + "timestamp": timestamp, + "enable_config_dump": self._get_enable_config_dump(), + "log_dir_prefix": str(log_dir_path), + "profiler": profiler_mode, + "profiling_driver_env": profiling_driver_env, + "prefill_profile_env": prefill_profile_env, + "decode_profile_env": decode_profile_env, + "aggregated_profile_env": aggregated_profile_env, + "setup_script": self.setup_script, + "use_gpus_per_node_directive": get_srtslurm_setting("use_gpus_per_node_directive", True), + "use_segment_sbatch_directive": get_srtslurm_setting("use_segment_sbatch_directive", 
True), + "extra_container_mounts": ",".join(self.config.get("extra_mount") or []), + } + + template_name = "job_script_template_agg.j2" if is_aggregated else "job_script_template_disagg.j2" + template_path = srtctl_root / "scripts" / "templates" / template_name + if not template_path.exists(): + raise FileNotFoundError(f"Template not found: {template_path}\nSet 'srtctl_root' in srtslurm.yaml") + + with open(template_path) as f: + rendered_script = Template(f.read()).render(**template_vars) + + fd, temp_path = tempfile.mkstemp(suffix=".sh", prefix="slurm_job_") + with os.fdopen(fd, "w") as f: + f.write(rendered_script) + logging.info(f"Generated SLURM job script: {temp_path}") + return Path(temp_path), rendered_script diff --git a/src/srtctl/core/schema.py b/src/srtctl/core/schema.py index e98c3396..1c290c07 100644 --- a/src/srtctl/core/schema.py +++ b/src/srtctl/core/schema.py @@ -17,6 +17,7 @@ from collections.abc import Iterator, Mapping from dataclasses import field from enum import Enum +import logging from pathlib import Path from typing import ( TYPE_CHECKING,