diff --git a/docs/README.md b/docs/README.md index 9f35a046..41126d0f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -14,18 +14,6 @@ Running large language models across multiple GPUs and nodes requires orchestrat - **Parameter sweeps** - Run grid searches across configurations with a single command - **Profiling support** - Built-in torch/nsys profiling modes -## Architecture Overview - -`srtctl` orchestrates distributed inference using SGLang workers in either **disaggregated** or **aggregated** mode. - -**Disaggregated Mode** separates prefill and decode into specialized workers: - -- Prefill workers handle the initial prompt processing -- Decode workers handle token generation -- Frontend distribution via nginx load balancer (default) or sglang_router - -**Aggregated Mode** runs combined prefill+decode on each worker, simpler but potentially less efficient for high-throughput scenarios. - ## How It Works When you run `srtctl apply -f config.yaml`, the tool: @@ -54,3 +42,4 @@ Once allocated, workers launch inside containers, discover each other through ET - [Parameter Sweeps](sweeps.md) - Run grid searches across configurations - [Profiling](profiling.md) - Performance analysis with torch/nsys - [Analyzing Results](analyzing.md) - Dashboard and visualization +- [SGLang Router](sglang-router.md) - Alternative to Dynamo for PD disaggregation diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 4d7747ad..74214c73 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -1,8 +1,10 @@ # Table of contents -* [Introduction](README.md) -* [Installation](installation.md) -* [Profiling](profiling.md) -* [Monitoring](monitoring.md) -* [Parameter Sweeps](sweeps.md) -* [Analying](analyzing.md) +- [Introduction](README.md) +- [Installation](installation.md) +- [SGLang Router](sglang-router.md) +- [Profiling](profiling.md) +- [Monitoring](monitoring.md) +- [Parameter Sweeps](sweeps.md) +- [Analyzing](analyzing.md) +- [SLURM FAQ](slurm-faq.md) diff --git a/docs/analyzing.md 
b/docs/analyzing.md index e24459f4..add10688 100644 --- a/docs/analyzing.md +++ b/docs/analyzing.md @@ -6,22 +6,5 @@ uv run streamlit run analysis/dashboard/app.py # Another way to launch dashboard make dashboard ``` -Opens interactive dashboard at http://localhost:8501 - - -## Features - -### šŸ“Š Interactive Dashboard -- **Pareto Analysis** - TPS/GPU vs TPS/User tradeoffs -- **Latency Breakdown** - TTFT, TPOT, ITL across concurrency levels -- **Node Metrics** - Runtime metrics from prefill/decode nodes -- **Config Comparison** - Side-by-side configuration diffs -- **Run Comparison** - Performance deltas between runs - -### šŸš€ SLURM Job Submission - -- Disaggregated (prefill/decode) or aggregated mode -- Multiple frontends with nginx load balancing (default) -- Automated benchmarking with sa-bench -- Job metadata tracking +Opens interactive dashboard at http://localhost:8501 diff --git a/docs/installation.md b/docs/installation.md index 233adccb..2d6079bf 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -18,6 +18,8 @@ pip install -e . ## Gather your cluster user and target partition +These commands might not work on all clusters. You can use AI to figure out the right set of commands for your cluster. + ```bash # user sacctmgr -nP show assoc where user=$(whoami) format=account @@ -27,6 +29,8 @@ sinfo ## Run Setup +If you are trying to deploy onto Grace (GH200, GB200, etc.), you need to use the `aarch64` architecture. Otherwise use `x86_64`. + ```bash make setup ARCH=aarch64 # or ARCH=x86_64 ``` @@ -42,8 +46,6 @@ The setup will: 3. Create `srtslurm.yaml` with your settings 4. Auto-detect and set `srtctl_root` path -Dynamo 0.7.0 is now available on PyPI and will be installed automatically from pip when workers start. 
- ## Configure srtslurm.yaml After setup, edit `srtslurm.yaml` to add model paths, containers, and cluster-specific settings: @@ -56,7 +58,6 @@ The `model_paths` section maps short aliases to full filesystem paths: model_paths: deepseek-r1: "/mnt/lustre/models/DeepSeek-R1" deepseek-r1-fp4: "/mnt/lustre/models/deepseek-r1-0528-fp4-v2" - llama-70b: "/mnt/lustre/models/Llama-3-70B" ``` Models must be accessible from all compute nodes (typically on a shared filesystem like Lustre or GPFS). @@ -67,15 +68,14 @@ The `containers` section maps version aliases to `.sqsh` container images: ```yaml containers: - latest: "/mnt/containers/lmsysorg+sglang+v0.5.5.sqsh" - stable: "/mnt/containers/lmsysorg+sglang+v0.5.4.sqsh" + container1: "/mnt/containers/lmsysorg+sglang+v0.5.5.sqsh" + container2: "/mnt/containers/lmsysorg+sglang+v0.5.4.sqsh" ``` To create a container image from Docker: ```bash enroot import docker://lmsysorg/sglang:v0.5.5 -mv lmsysorg+sglang+v0.5.5.sqsh /mnt/containers/ ``` ### Cloud Sync (Optional) @@ -91,42 +91,43 @@ cloud: Then use `make sync-to-cloud` or `make sync-run RUN_ID=`. -### Cluster Compatibility Settings - -Some SLURM clusters don't support certain SBATCH directives. If you encounter errors during job submission, you may need to adjust these settings: - -#### GPU Resource Specification - -If you see this error when submitting jobs: +### Complete srtslurm.yaml Reference -``` -sbatch: error: Invalid generic resource (gres) specification -``` - -Your cluster doesn't support the `--gpus-per-node` directive. Disable it in `srtslurm.yaml`: +Here's a complete example of all available options: ```yaml -use_gpus_per_node_directive: false -``` +# Default SLURM settings +default_account: "your-account" +default_partition: "batch" +default_time_limit: "4:00:00" -This will omit the `#SBATCH --gpus-per-node` directive from generated job scripts while keeping all other functionality intact. 
+# Resource defaults +gpus_per_node: 4 -#### Segment-Based Scheduling +# SLURM directive compatibility +use_gpus_per_node_directive: true # Set false if cluster doesn't support --gpus-per-node +use_segment_sbatch_directive: true # Set false if cluster doesn't support --segment -If you see this error when submitting jobs: +# Path to srtctl repo root (auto-set by make setup) +srtctl_root: "/path/to/srtctl" -``` -sbatch: error: Invalid --segment specification -``` +# Model path aliases +model_paths: + deepseek-r1: "/models/DeepSeek-R1" + llama-70b: "/models/Llama-3-70B" -Your cluster doesn't support the `--segment` directive for topology-aware scheduling. Disable it in `srtslurm.yaml`: +# Container aliases +containers: + latest: "/containers/sglang-latest.sqsh" + stable: "/containers/sglang-stable.sqsh" -```yaml -use_segment_sbatch_directive: false +# Cloud sync settings (optional) +cloud: + endpoint_url: "https://s3.example.com" + bucket: "benchmark-results" + prefix: "my-team/" ``` -The `--segment` directive ensures all allocated nodes are within the same network segment/switch for optimal interconnect performance between prefill and decode workers. If your cluster doesn't support it, SLURM will still allocate nodes but may scatter them across the cluster. - ## Create a Job Config Create `configs/my-job.yaml`: @@ -178,7 +179,7 @@ benchmark: isl: 1024 osl: 1024 concurrencies: [256, 512] - req_rate: "inf" # Request rate, use "inf" for max throughput + req_rate: "inf" ``` ### Backend Options @@ -195,35 +196,6 @@ backend: use_sglang_router: false # Default: false. 
Use sglang_router for load balancing ``` -## Profiling (torch / nsys) - -You can enable profiling via a top-level `profiling` section in your job YAML: - -```yaml -profiling: - type: "torch" # one of: "none", "torch", "nsys" - isl: 1024 - osl: 128 - concurrency: 24 - start_step: 0 # optional - stop_step: 50 # optional - -benchmark: - type: "manual" # Required - profiling and benchmarking are mutually exclusive -``` - -See [Profiling](profiling.md) for detailed configuration options, constraints, and output file locations. - -## Validate with Dry Run - -Always validate before submitting: - -```bash -srtctl dry-run -f configs/my-job.yaml -``` - -This validates your config, resolves aliases, generates all files, and saves them to `dry-runs/` without submitting to SLURM. - ## Submit the Job ```bash @@ -285,43 +257,6 @@ You can run custom initialization scripts on worker nodes before starting SGLang srtctl apply -f configs/my-job.yaml --setup-script custom-setup.sh ``` -The script will be executed on each worker node (prefill, decode, and aggregated) before installing Dynamo from PyPI and starting the SGLang workers. The script must be located in the `configs/` directory, which is mounted into containers at `/configs/`. +The script will be executed on each worker node (prefill, decode, or aggregated) before installing Dynamo from PyPI and starting the SGLang workers. The script must be located in the `configs/` directory, which is mounted into containers at `/configs/`. **Note**: Setup scripts only run when you explicitly specify `--setup-script`. No default setup script will run if this flag is omitted. 
- -## Complete srtslurm.yaml Reference - -Here's a complete example of all available options: - -```yaml -# Default SLURM settings -default_account: "your-account" -default_partition: "batch" -default_time_limit: "4:00:00" - -# Resource defaults -gpus_per_node: 4 - -# SLURM directive compatibility -use_gpus_per_node_directive: true # Set false if cluster doesn't support --gpus-per-node -use_segment_sbatch_directive: true # Set false if cluster doesn't support --segment - -# Path to srtctl repo root (auto-set by make setup) -srtctl_root: "/path/to/srtctl" - -# Model path aliases -model_paths: - deepseek-r1: "/models/DeepSeek-R1" - llama-70b: "/models/Llama-3-70B" - -# Container aliases -containers: - latest: "/containers/sglang-latest.sqsh" - stable: "/containers/sglang-stable.sqsh" - -# Cloud sync settings (optional) -cloud: - endpoint_url: "https://s3.example.com" - bucket: "benchmark-results" - prefix: "my-team/" -``` diff --git a/docs/profiling.md b/docs/profiling.md index 182c2d38..3a62743d 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -1,12 +1,16 @@ # Profiling -srtctl supports two profiling backends for performance analysis: **Torch Profiler** and **NVIDIA Nsight Systems (nsys)**. Profiling helps identify bottlenecks in prefill and decode operations. +srtctl supports two profiling backends for performance analysis: **Torch Profiler** and **NVIDIA Nsight Systems (nsys)**. ## Quick Start Add a `profiling` section to your job YAML: ```yaml +# must set benchmark type to "manual" +benchmark: + type: "manual" + profiling: type: "torch" # or "nsys" isl: 1024 diff --git a/docs/sglang-router.md b/docs/sglang-router.md new file mode 100644 index 00000000..533e8f52 --- /dev/null +++ b/docs/sglang-router.md @@ -0,0 +1,210 @@ +# SGLang Router Mode + +This page explains the sglang router mode for prefill-decode (PD) disaggregation, an alternative to the default Dynamo frontend architecture. 
+ +## Overview + +By default, srtctl uses **Dynamo frontends** to coordinate between prefill and decode workers. This requires NATS/ETCD infrastructure and the `dynamo` package. + +**SGLang Router** is an alternative that uses sglang's native `sglang_router` for PD disaggregation. + +| Feature | Dynamo Frontends | SGLang Router | +| -------------- | -------------------------- | -------------------------- | +| Infrastructure | NATS + ETCD + dynamo | sglang_router only | +| Routing | Dynamo's coordination | sglang's native PD routing | +| Scaling | nginx + multiple frontends | nginx + multiple routers | + +## Configuration + +Enable sglang router in your recipe's `backend` section: + +```yaml +backend: + use_sglang_router: true +``` + +That's it. The workers will launch with `sglang.launch_server` instead of `dynamo.sglang`, and the router will handle request distribution. + +## Architecture Modes + +### Single Router (`enable_multiple_frontends: false`) + +The simplest mode - one router on node 0, no nginx: + +```yaml +backend: + use_sglang_router: true + enable_multiple_frontends: false +``` + +``` +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Node 0 │ +│ ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” │ +│ │ sglang-router │ │ Prefill │ │ Decode │ │ +│ │ :8000 │──│ Worker │──│ Worker │ │ +│ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +- Router directly on port 8000 +- Good for testing or small deployments +- No load balancing overhead + +### 
Multiple Routers (`enable_multiple_frontends: true`, default) + +Nginx load balances across multiple router instances: + +```yaml +backend: + use_sglang_router: true + enable_multiple_frontends: true # default + num_additional_frontends: 9 # default, total = 1 + 9 = 10 routers +``` + +``` +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Node 0 Node 1 Node 2 │ +│ ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” │ +│ │ nginx │ │ sglang-router │ │ sglang- │ │ sglang- │ │ +│ │ :8000 │──│ :30080 │ │ router │ │ router │ │ +│ ā””ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ │ :30080 │ │ :30080 │ │ +│ │ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ │ +│ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”“ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +│ │ +│ ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” │ +│ │ Prefill │ │ Prefill │ │ Decode │ │ Decode │ │ +│ │ Worker 0 │ │ Worker 1 │ │ Worker 0 │ │ Worker 1 │ │ +│ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +- nginx on node 0 listens on port 8000 (public) +- Routers listen on port 30080 
(internal) +- nginx round-robins requests to routers +- Routers distributed across nodes using same logic as Dynamo frontends + +## How Router Distribution Works + +The `num_additional_frontends` setting controls how many additional routers spawn beyond the first: + +| Setting | Total Routers | Distribution | +| ----------------------------- | ------------- | -------------------------------- | +| `num_additional_frontends: 0` | 1 | Node 0 only | +| `num_additional_frontends: 4` | 5 | Node 0 + 4 distributed | +| `num_additional_frontends: 9` | 10 | Node 0 + 9 distributed (default) | + +Routers are distributed across available nodes using ceiling division: + +``` +nodes_per_router = ceil((total_nodes - 1) / num_additional_frontends) +``` + +## Port Configuration + +### Bootstrap Port + +The sglang router needs the **disaggregation bootstrap port** to connect to prefill workers. This must match the `disaggregation-bootstrap-port` in your sglang config: + +```yaml +backend: + sglang_config: + prefill: + disaggregation-bootstrap-port: 30001 # Must match + # ... other config + decode: + disaggregation-bootstrap-port: 30001 # Must match + # ... other config +``` + +The default bootstrap port is `30001` (matching most recipes). If you use a different port, ensure it's consistent across prefill and decode configs. + +### Server Port + +Workers listen on port `30000` by default. This is standard sglang behavior and doesn't need configuration. 
+ +## Complete Example + +Here's a full recipe using sglang router: + +```yaml +name: "deepseek-r1-sglang-router" + +model: + path: "deepseek-r1-fp4" + container: "sglang-latest" + precision: "fp4" + +resources: + gpu_type: "gb300" + gpus_per_node: 4 + prefill_nodes: 2 + prefill_workers: 2 + decode_nodes: 2 + decode_workers: 2 + +backend: + use_sglang_router: true + enable_multiple_frontends: true + num_additional_frontends: 3 # 4 total routers + + sglang_config: + prefill: + model-path: /model/ + tensor-parallel-size: 4 + disaggregation-mode: prefill + disaggregation-bootstrap-port: 30001 + disaggregation-transfer-backend: nixl + # ... other prefill settings + + decode: + model-path: /model/ + tensor-parallel-size: 4 + disaggregation-mode: decode + disaggregation-bootstrap-port: 30001 + disaggregation-transfer-backend: nixl + # ... other decode settings + +benchmark: + type: "sa-bench" + isl: 128000 + osl: 8000 + concurrencies: "16x32" +``` + +## Troubleshooting + +### Port Conflicts + +If you see `bind() to 0.0.0.0:8000 failed (Address already in use)`: + +- This means nginx and a router are both trying to use port 8000 +- Ensure you're using the latest template (routers use port 30080 internally) + +### Router Not Connecting to Workers + +Check that: + +1. `disaggregation-bootstrap-port` matches in prefill/decode configs +2. Workers are fully started before router tries to connect +3. Network connectivity between router and worker nodes + +### Benchmark Can't Reach Endpoint + +The benchmark connects to `http://:8000`. 
Ensure: + +- nginx is running (if `enable_multiple_frontends: true`) +- Router is running (if `enable_multiple_frontends: false`) +- Port 8000 is accessible + +## Comparison with Dynamo + +| Aspect | Dynamo Frontends | SGLang Router | +| -------------- | ----------------------------------- | ------------------------ | +| **Startup** | Slower (NATS/ETCD + dynamo install) | Faster (just sglang) | +| **Complexity** | More moving parts | Simpler | +| **Maturity** | Production-tested | Newer | +| **Config** | Via dynamo.sglang | Via sglang.launch_server | +| **Scaling** | Same nginx approach | Same nginx approach | + +Both modes support the same `enable_multiple_frontends` and `num_additional_frontends` settings for horizontal scaling. diff --git a/docs/slurm-faq.md b/docs/slurm-faq.md new file mode 100644 index 00000000..fa442327 --- /dev/null +++ b/docs/slurm-faq.md @@ -0,0 +1,37 @@ +# SLURM FAQ + +## Cluster Compatibility Settings + +Some SLURM clusters don't support certain SBATCH directives. If you encounter errors during job submission, you may need to adjust these settings in your `srtslurm.yaml`. + +## GPU Resource Specification + +If you see this error when submitting jobs: + +``` +sbatch: error: Invalid generic resource (gres) specification +``` + +Your cluster doesn't support the `--gpus-per-node` directive. Disable it: + +```yaml +use_gpus_per_node_directive: false +``` + +This will omit the `#SBATCH --gpus-per-node` directive from generated job scripts while keeping all other functionality intact. + +## Segment-Based Scheduling + +If you see this error when submitting jobs: + +``` +sbatch: error: Invalid --segment specification +``` + +Your cluster doesn't support the `--segment` directive for topology-aware scheduling. Disable it: + +```yaml +use_segment_sbatch_directive: false +``` + +The `--segment` directive ensures all allocated nodes are within the same network segment/switch for optimal interconnect performance between prefill and decode workers. 
If your cluster doesn't support it, SLURM will still allocate nodes but may scatter them across the cluster. diff --git a/scripts/benchmarks/sa-bench/bench.sh b/scripts/benchmarks/sa-bench/bench.sh index 70b8aa21..882cb026 100755 --- a/scripts/benchmarks/sa-bench/bench.sh +++ b/scripts/benchmarks/sa-bench/bench.sh @@ -24,14 +24,15 @@ chosen_osl=$6 concurrency_list=$7 IFS='x' read -r -a chosen_concurrencies <<< "$concurrency_list" chosen_req_rate=$8 +use_sglang_router=${9:-false} -echo "Config ${chosen_isl}; ${chosen_osl}; ${chosen_concurrencies[@]}; ${chosen_req_rate}" +echo "Config ${chosen_isl}; ${chosen_osl}; ${chosen_concurrencies[@]}; ${chosen_req_rate}; sglang_router=${use_sglang_router}" wait_for_model_timeout=3600 # 1 hour wait_for_model_check_interval=5 # check interval -> 5s wait_for_model_report_interval=60 # wait_for_model report interval -> 60s -wait_for_model $head_node $head_port $n_prefill $n_decode $wait_for_model_check_interval $wait_for_model_timeout $wait_for_model_report_interval +wait_for_model $head_node $head_port $n_prefill $n_decode $wait_for_model_check_interval $wait_for_model_timeout $wait_for_model_report_interval $use_sglang_router # run a quick curl request against the model to do an accuracy spot check curl http://${head_node}:${head_port}/v1/chat/completions -H "Content-Type: application/json" -d '{ diff --git a/scripts/templates/job_script_template_disagg.j2 b/scripts/templates/job_script_template_disagg.j2 index ddcf87ea..fbccd795 100755 --- a/scripts/templates/job_script_template_disagg.j2 +++ b/scripts/templates/job_script_template_disagg.j2 @@ -338,9 +338,8 @@ echo "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --overlap --pty bas {% endif %} {% raw %} -# Launch sglang router when enabled +# Launch sglang router(s) when enabled {% endraw %}{% if use_sglang_router %}{% raw %} -echo "Launching sglang router on ${nodes[0]}" # Collect leader IPs for prefill and decode PREFILL_LEADER_IPS=() for idx in 
"${prefill_leaders[@]}"; do @@ -355,19 +354,122 @@ for idx in "${decode_leaders[@]}"; do DECODE_LEADER_IPS+=("$ip") done -ROUTER_ARGS="--pd-disaggregation" -for ip in "${PREFILL_LEADER_IPS[@]}"; do - ROUTER_ARGS="$ROUTER_ARGS --prefill http://${ip}:30000" -done -for ip in "${DECODE_LEADER_IPS[@]}"; do - ROUTER_ARGS="$ROUTER_ARGS --decode http://${ip}:30000" -done +PREFILL_IPS_STR=$(IFS=,; echo "${PREFILL_LEADER_IPS[*]}") +DECODE_IPS_STR=$(IFS=,; echo "${DECODE_LEADER_IPS[*]}") + +{% endraw %} +{% if enable_multiple_frontends %} +{% raw %} +# Multiple router architecture (mirrors dynamo frontend scaling) +# Node 0: nginx load balancer + first router +# Node 1+: additional routers distributed across worker nodes + +NGINX_NODE=${nodes[0]} +NGINX_IP=$(get_node_ip "$NGINX_NODE" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + +# Build router host/IP lists +router_hosts=() +router_ips=() + +# First router always on node 0 +router_hosts+=("$NGINX_NODE") +router_ips+=("$NGINX_IP") + +# Add additional routers (uses same num_additional_frontends setting as dynamo) +{% endraw %}ADDITIONAL_ROUTERS={{ num_additional_frontends }}{% raw %} +if [ "$ADDITIONAL_ROUTERS" -gt 0 ]; then + # Calculate which nodes get additional routers + # Distribute additional routers across nodes, starting from node 1 + nodes_per_router=$(( (TOTAL_NODES - 1 + ADDITIONAL_ROUTERS - 1) / ADDITIONAL_ROUTERS )) # ceil division + router_node_idx=1 # Start from node 1 (node 0 already has first router) + + for i in $(seq 1 $ADDITIONAL_ROUTERS); do + if [ $router_node_idx -lt $TOTAL_NODES ]; then + node_name=${nodes[$router_node_idx]} + node_ip=$(get_node_ip "$node_name" "$SLURM_JOB_ID" "$NETWORK_INTERFACE") + router_hosts+=("$node_name") + router_ips+=("$node_ip") + echo "Additional router $i on node $router_node_idx: $node_name ($node_ip)" + router_node_idx=$((router_node_idx + nodes_per_router)) + fi + done +fi + +echo "Router hosts: ${router_hosts[@]}" +echo "Router IPs: ${router_ips[@]}" + +# Generate nginx 
configuration for router load balancing +# Routers use internal port 30080, nginx exposes on 8000 +ROUTER_INTERNAL_PORT=30080 +ROUTER_LIST=$(printf "'%s'," "${router_ips[@]}") +ROUTER_LIST="[${ROUTER_LIST%,}]" +export ROUTER_LIST ROUTER_INTERNAL_PORT SCRIPT_DIR LOG_DIR +python3 - <<'PY' +import os +from jinja2 import Template + +template_path = os.path.join(os.environ['SCRIPT_DIR'], 'templates/nginx.conf.j2') +output_path = os.path.join(os.environ['LOG_DIR'], 'nginx.conf') + +with open(template_path, 'r') as f: + tmpl = Template(f.read()) + +router_hosts = eval(os.environ['ROUTER_LIST']) +backend_port = int(os.environ['ROUTER_INTERNAL_PORT']) +config = tmpl.render(frontend_hosts=router_hosts, backend_port=backend_port) + +with open(output_path, 'w') as f: + f.write(config) +PY + +# Launch nginx on node 0 +echo "Launching nginx for router load balancing on ${NGINX_NODE}" +cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$NGINX_NODE --output=${LOG_DIR}/${NGINX_NODE}_nginx.out python /scripts/worker_setup.py --worker_type nginx --nginx_config /logs/nginx.conf ${WORKER_ARGS}" +echo "$cmd" +$cmd & + +# Launch first router on node 0 (with nginx) +# Router listens on internal port, nginx proxies from 8000 +echo "Launching sglang-router 0 on ${NGINX_NODE} (internal port ${ROUTER_INTERNAL_PORT})" +cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$NGINX_NODE --output=${LOG_DIR}/${NGINX_NODE}_router_0.out python /scripts/worker_setup.py --worker_type sglang-router --worker_idx 0 --prefill-ips ${PREFILL_IPS_STR} --decode-ips ${DECODE_IPS_STR} --router-port ${ROUTER_INTERNAL_PORT} ${WORKER_ARGS}" +echo "$cmd" +$cmd & + +# Launch additional routers on designated nodes +if [ "$ADDITIONAL_ROUTERS" -gt 0 ]; then + router_idx=1 + nodes_per_router=$(( (TOTAL_NODES - 1 + ADDITIONAL_ROUTERS - 1) / ADDITIONAL_ROUTERS )) + router_node_idx=1 + + for i in $(seq 1 $ADDITIONAL_ROUTERS); do + if [ $router_node_idx -lt $TOTAL_NODES ]; then + 
node=${nodes[$router_node_idx]} + echo "Launching sglang-router $router_idx on node $router_node_idx: $node (internal port ${ROUTER_INTERNAL_PORT})" + cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_router_${router_idx}.out python /scripts/worker_setup.py --worker_type sglang-router --worker_idx ${router_idx} --prefill-ips ${PREFILL_IPS_STR} --decode-ips ${DECODE_IPS_STR} --router-port ${ROUTER_INTERNAL_PORT} ${WORKER_ARGS}" + echo "$cmd" + $cmd & + router_idx=$((router_idx + 1)) + router_node_idx=$((router_node_idx + nodes_per_router)) + fi + done +fi +TOTAL_ROUTERS=$((1 + ADDITIONAL_ROUTERS)) +echo "Frontend available at: http://${NGINX_NODE}:8000 (nginx load balancing ${TOTAL_ROUTERS} sglang-routers)" +{% endraw %} +{% else %} +{% raw %} +# Single router architecture - no nginx, router directly on node 0 port 8000 ROUTER_NODE=${nodes[0]} -cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$ROUTER_NODE --output=${LOG_DIR}/${ROUTER_NODE}_router.out python -m sglang_router.launch_router $ROUTER_ARGS --host 0.0.0.0 --port 8000" +echo "Launching single sglang-router on ${ROUTER_NODE} (port 8000)" +cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$ROUTER_NODE --output=${LOG_DIR}/${ROUTER_NODE}_router.out python /scripts/worker_setup.py --worker_type sglang-router --worker_idx 0 --prefill-ips ${PREFILL_IPS_STR} --decode-ips ${DECODE_IPS_STR} --router-port 8000 ${WORKER_ARGS}" echo "$cmd" $cmd & -{% endraw %}{% endif %} + +echo "Frontend available at: http://${ROUTER_NODE}:8000" +{% endraw %} +{% endif %} +{% endif %} {% raw %} echo "" @@ -380,10 +482,11 @@ echo "scancel $SLURM_JOB_ID" BENCHMARK_TYPE={{ benchmark_type }} BENCHMARK_ARGS="{{ benchmark_arg }}" +USE_SGLANG_ROUTER={{ "true" if use_sglang_router else "false" }} {% if do_benchmark %} {% raw %} -srun --nodes=1 --ntasks=1 $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --output=${LOG_DIR}/benchmark.out --overlap bash 
/scripts/benchmarks/${BENCHMARK_TYPE}/bench.sh $PREFILL_WORKERS $DECODE_WORKERS $PREFILL_GPUS $DECODE_GPUS ${BENCHMARK_ARGS} & +srun --nodes=1 --ntasks=1 $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --output=${LOG_DIR}/benchmark.out --overlap bash /scripts/benchmarks/${BENCHMARK_TYPE}/bench.sh $PREFILL_WORKERS $DECODE_WORKERS $PREFILL_GPUS $DECODE_GPUS ${BENCHMARK_ARGS} ${USE_SGLANG_ROUTER} & {% endraw %} {% endif %} diff --git a/scripts/templates/nginx.conf.j2 b/scripts/templates/nginx.conf.j2 index c66f2c70..e06e5227 100644 --- a/scripts/templates/nginx.conf.j2 +++ b/scripts/templates/nginx.conf.j2 @@ -6,7 +6,7 @@ http { access_log off; upstream backend_servers { {% for frontend_host in frontend_hosts %} - server {{ frontend_host }}:8000; + server {{ frontend_host }}:{{ backend_port | default(8000) }}; {% endfor %} } diff --git a/scripts/utils/benchmark_utils.sh b/scripts/utils/benchmark_utils.sh index c2f584b9..09a6e8fa 100755 --- a/scripts/utils/benchmark_utils.sh +++ b/scripts/utils/benchmark_utils.sh @@ -11,18 +11,31 @@ wait_for_model() { local poll=${5:-1} local timeout=${6:-600} local report_every=${7:-60} + local use_sglang_router=${8:-false} local health_addr="http://${model_host}:${model_port}/health" - echo "Polling ${health_addr} every ${poll} seconds to check whether ${n_prefill} prefills and ${n_decode} decodes are alive" + local workers_addr="http://${model_host}:${model_port}/workers" + + if [[ $use_sglang_router == "true" ]]; then + echo "Polling ${workers_addr} every ${poll} seconds to check whether ${n_prefill} prefills and ${n_decode} decodes are alive (sglang router mode)" + else + echo "Polling ${health_addr} every ${poll} seconds to check whether ${n_prefill} prefills and ${n_decode} decodes are alive" + fi local start_ts=$(date +%s) local report_ts=$(date +%s) while :; do - # Curl timeout - our primary use case here is to launch it at the first node (localhost), so no timeout is needed. 
- curl_result=$(curl ${health_addr} 2>/dev/null) - # Python path - Use of `check_server_health.py` is self-constrained outside of any packaging. - check_result=$(python3 /scripts/utils/check_server_health.py $n_prefill $n_decode <<< $curl_result) + if [[ $use_sglang_router == "true" ]]; then + # sglang router: use /workers endpoint for worker counts + curl_result=$(curl ${workers_addr} 2>/dev/null) + check_result=$(python3 /scripts/utils/check_server_health.py $n_prefill $n_decode --sglang-router <<< $curl_result) + else + # dynamo: use /health endpoint + curl_result=$(curl ${health_addr} 2>/dev/null) + check_result=$(python3 /scripts/utils/check_server_health.py $n_prefill $n_decode <<< $curl_result) + fi + if [[ $check_result == *"Model is ready."* ]]; then echo $check_result return 0 diff --git a/scripts/utils/check_server_health.py b/scripts/utils/check_server_health.py index 218f75a0..5cfc0f79 100644 --- a/scripts/utils/check_server_health.py +++ b/scripts/utils/check_server_health.py @@ -2,45 +2,53 @@ # SPDX-License-Identifier: Apache-2.0 # pytest: skip-file +import argparse import json import sys """ -A file that parses the response of `curl :/health` endpoint -to check whether the server is ready to be benchmarked. +A file that parses the response of server health endpoints to check whether the server is ready. + +Supports two modes: +- Dynamo frontend: parses /health endpoint JSON with 'instances' key +- SGLang router: parses /workers endpoint JSON with 'stats' key Usage: ```bash +# Dynamo mode (default) curl_result=$(curl "${host_ip}:${host_port}/health" 2> /dev/null) check_result=$(python3 check_server_health.py $N_PREFILL $N_DECODE <<< $curl_result) -# ... then do subsequent processing for check_result ... 
+# SGLang router mode +curl_result=$(curl "${host_ip}:${host_port}/workers" 2> /dev/null) +check_result=$(python3 check_server_health.py $N_PREFILL $N_DECODE --sglang-router <<< $curl_result) ``` """ -def check_server_health(expected_n_prefill, expected_n_decode, response): - """ - Checks the health of the server's response - and ensures that the number of spinned up prefill & decode - matches our expectation. - --- - Parameter: - - expected_n_prefill: string (expect integer), number of expected prefill workers. - - expected_n_decode: string (expect integer), number of expected decode workers. - - response: string, formatted `curl /health` curl results, - should be JSON-parsable - - Returns: - string, a pretty-printable string that tell the current status. - """ - if not (expected_n_prefill.isnumeric() and expected_n_decode.isnumeric()): - return f"Got unparsable expected prefill / decode value: {expected_n_prefill} & {expected_n_decode} should be string" +def check_sglang_router_health(expected_n_prefill: int, expected_n_decode: int, response: str) -> str: + """Check health using sglang router /workers endpoint.""" + try: + decoded_response = json.loads(response) + except json.JSONDecodeError: + return f"Got invalid response from server that leads to JSON Decode error: {response}" - expected_n_prefill = int(expected_n_prefill) - expected_n_decode = int(expected_n_decode) + if "stats" not in decoded_response: + return f"Key 'stats' not found in response: {response}" + stats = decoded_response["stats"] + actual_prefill = stats.get("prefill_count", 0) + actual_decode = stats.get("decode_count", 0) + + if actual_prefill >= expected_n_prefill and actual_decode >= expected_n_decode: + return f"Model is ready. Have {actual_prefill} prefills and {actual_decode} decodes." + else: + return f"Model is not ready, waiting for {expected_n_prefill - actual_prefill} prefills and {expected_n_decode - actual_decode} decodes. 
Have {actual_prefill} prefills and {actual_decode} decodes." + + +def check_dynamo_health(expected_n_prefill: int, expected_n_decode: int, response: str) -> str: + """Check health using dynamo frontend /health endpoint.""" try: decoded_response = json.loads(response) except json.JSONDecodeError: @@ -71,28 +79,37 @@ def check_server_health(expected_n_prefill, expected_n_decode, response): return f"Model is not ready, waiting for {expected_n_prefill} prefills and {expected_n_decode} decodes to spin up. Response: {response}" -if __name__ == "__main__": +def check_server_health( + expected_n_prefill: str, expected_n_decode: str, response: str, sglang_router: bool = False +) -> str: """ - Usage - - provide the expected number of prefill / decode as sys args - and then provide the `curl` response as an input. - E.g.: - ```bash - curl_result=$(curl "${host_ip}:${host_port}/health" 2> /dev/null) - check_result=$(python3 check_server_health.py $N_PREFILL $N_DECODE <<< $curl_result) - - # ... then do subsequent processing for check_result ... - ``` + Checks the health of the server's response and ensures worker counts match expectation. 
""" + if not (expected_n_prefill.isnumeric() and expected_n_decode.isnumeric()): + return f"Got unparsable expected prefill / decode value: {expected_n_prefill} & {expected_n_decode} should be numeric" + + n_prefill = int(expected_n_prefill) + n_decode = int(expected_n_decode) - expected_n_prefill = sys.argv[1] - expected_n_decode = sys.argv[2] + if sglang_router: + return check_sglang_router_health(n_prefill, n_decode, response) + else: + return check_dynamo_health(n_prefill, n_decode, response) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Check server health for benchmarking") + parser.add_argument("n_prefill", help="Expected number of prefill workers") + parser.add_argument("n_decode", help="Expected number of decode workers") + parser.add_argument("--sglang-router", action="store_true", help="Use sglang router /workers format") + args = parser.parse_args() response = sys.stdin.read() print( check_server_health( - expected_n_prefill=expected_n_prefill, - expected_n_decode=expected_n_decode, + expected_n_prefill=args.n_prefill, + expected_n_decode=args.n_decode, response=response, + sglang_router=args.sglang_router, ) ) diff --git a/scripts/worker_setup.py b/scripts/worker_setup.py index 6a4fe9b8..ad2b35eb 100644 --- a/scripts/worker_setup.py +++ b/scripts/worker_setup.py @@ -26,6 +26,7 @@ setup_logging, setup_nginx_worker, setup_prefill_worker, + setup_router_worker, ) @@ -66,7 +67,7 @@ def _parse_command_line_args(args: list[str] | None = None) -> argparse.Namespac ) parser.add_argument( "--worker_type", - choices=["decode", "prefill", "frontend", "nginx", "aggregated"], + choices=["decode", "prefill", "frontend", "nginx", "aggregated", "sglang-router"], required=True, help="Type of worker to run", ) @@ -89,6 +90,36 @@ def _parse_command_line_args(args: list[str] | None = None) -> argparse.Namespac help="Path to nginx configuration file (required for nginx worker type)", ) + # sglang-router-specific arguments + 
parser.add_argument( + "--prefill-ips", + type=str, + help="Comma-separated list of prefill worker leader IPs (required for sglang-router worker type)", + ) + parser.add_argument( + "--decode-ips", + type=str, + help="Comma-separated list of decode worker leader IPs (required for sglang-router worker type)", + ) + parser.add_argument( + "--router-port", + type=int, + default=8000, + help="Port for the router to listen on (default: 8000)", + ) + parser.add_argument( + "--server-port", + type=int, + default=30000, + help="Port where prefill/decode servers listen (default: 30000)", + ) + parser.add_argument( + "--bootstrap-port", + type=int, + default=30001, + help="Disaggregation bootstrap port for prefill servers (default: 30001)", + ) + parser.add_argument( "--multiple-frontends-enabled", action="store_true", @@ -158,6 +189,13 @@ def _validate_args(args: argparse.Namespace) -> None: if args.worker_type == "nginx" and not args.nginx_config: raise ValueError("--nginx_config is required for nginx worker type") + # Validate sglang-router-specific arguments + if args.worker_type == "sglang-router": + if not args.prefill_ips: + raise ValueError("--prefill-ips is required for sglang-router worker type") + if not args.decode_ips: + raise ValueError("--decode-ips is required for sglang-router worker type") + def main(input_args: list[str] | None = None): setup_logging() @@ -225,6 +263,18 @@ def main(input_args: list[str] | None = None): args.setup_script, args.use_sglang_router, ) + elif args.worker_type == "sglang-router": + prefill_ips = [ip.strip() for ip in args.prefill_ips.split(",") if ip.strip()] + decode_ips = [ip.strip() for ip in args.decode_ips.split(",") if ip.strip()] + setup_router_worker( + router_idx=args.worker_idx or 0, + prefill_ips=prefill_ips, + decode_ips=decode_ips, + host="0.0.0.0", + port=args.router_port, + server_port=args.server_port, + bootstrap_port=args.bootstrap_port, + ) logging.info(f"{args.worker_type.capitalize()} worker setup complete") 
diff --git a/scripts/worker_setup/__init__.py b/scripts/worker_setup/__init__.py index 2c5b8210..125607c8 100644 --- a/scripts/worker_setup/__init__.py +++ b/scripts/worker_setup/__init__.py @@ -7,7 +7,7 @@ from .environment import setup_env from .infrastructure import setup_frontend_worker, setup_head_prefill_node, setup_nginx_worker from .utils import setup_logging, wait_for_etcd -from .worker import setup_aggregated_worker, setup_decode_worker, setup_prefill_worker +from .worker import setup_aggregated_worker, setup_decode_worker, setup_prefill_worker, setup_router_worker __all__ = [ # Command building @@ -27,4 +27,5 @@ "setup_aggregated_worker", "setup_decode_worker", "setup_prefill_worker", + "setup_router_worker", ] diff --git a/scripts/worker_setup/command.py b/scripts/worker_setup/command.py index 6a4a5aa5..9bc81d03 100644 --- a/scripts/worker_setup/command.py +++ b/scripts/worker_setup/command.py @@ -17,13 +17,14 @@ def build_sglang_command_from_yaml( rank: int, profiler: str = "none", dump_config_path: str | None = None, + use_sglang_router: bool = False, ) -> str: """Build SGLang command using native YAML config support. dynamo.sglang supports reading config from YAML: python3 -m dynamo.sglang --config file.yaml --config-key prefill - sglang.launch_server (profiling mode) requires explicit flags: + sglang.launch_server (profiling mode or sglang router mode) requires explicit flags: python3 -m sglang.launch_server --model-path /model/ --tp 4 ... 
Args: @@ -34,6 +35,7 @@ def build_sglang_command_from_yaml( total_nodes: Total number of nodes rank: Node rank (0-indexed) profiler: Profiling method: "none", "torch", or "nsys" + use_sglang_router: Use sglang.launch_server instead of dynamo.sglang Returns: Full command string ready to execute @@ -57,11 +59,13 @@ def build_sglang_command_from_yaml( if profiler == "torch": env_exports.append(f"export SGLANG_TORCH_PROFILER_DIR=/logs/profiles/{config_key}") - # Determine Python module based on profiling mode - python_module = "sglang.launch_server" if profiler != "none" else "dynamo.sglang" + # Determine Python module based on profiling mode or sglang router mode + # Use sglang.launch_server when profiling OR when using sglang router (no dynamo) + use_launch_server = profiler != "none" or use_sglang_router + python_module = "sglang.launch_server" if use_launch_server else "dynamo.sglang" nsys_prefix = f"nsys profile -t cuda,nvtx --cuda-graph-trace=node -c cudaProfilerApi --capture-range-end stop --force-overwrite true -o /logs/profiles/{config_key}_{rank}" - if profiler != "none": + if use_launch_server: # Profiling mode: inline all flags (sglang.launch_server doesn't support --config) mode_config = sglang_config.get(config_key, {}) # Wrap with NSYS on all ranks; outputs are isolated per-rank @@ -103,8 +107,8 @@ def build_sglang_command_from_yaml( "--host 0.0.0.0", ] - # Add dump-config-to flag if provided - if dump_config_path: + # Add dump-config-to flag if provided (not supported by sglang.launch_server) + if dump_config_path and not use_sglang_router: cmd_parts.append(f"--dump-config-to {dump_config_path}") # Combine environment exports and command @@ -149,6 +153,7 @@ def get_gpu_command( rank: int, profiler: str = "none", dump_config_path: str | None = None, + use_sglang_router: bool = False, ) -> str: """Generate command to run SGLang worker using YAML config. 
@@ -160,6 +165,7 @@ def get_gpu_command( total_nodes: Total number of nodes rank: Node rank (0-indexed) profiler: Profiling method: "none", "torch", or "nsys" + use_sglang_router: Use sglang.launch_server instead of dynamo.sglang Returns: Command string to execute @@ -169,5 +175,5 @@ def get_gpu_command( logging.info(f"Building command from YAML config: {sglang_config_path}") return build_sglang_command_from_yaml( - worker_type, sglang_config_path, host_ip, port, total_nodes, rank, profiler, dump_config_path + worker_type, sglang_config_path, host_ip, port, total_nodes, rank, profiler, dump_config_path, use_sglang_router ) diff --git a/scripts/worker_setup/worker.py b/scripts/worker_setup/worker.py index 42f92469..95d6594e 100644 --- a/scripts/worker_setup/worker.py +++ b/scripts/worker_setup/worker.py @@ -100,14 +100,14 @@ def setup_prefill_worker( if not wait_for_etcd(f"http://{master_ip}:{ETCD_CLIENT_PORT}"): raise RuntimeError("Failed to connect to etcd") - # Install dynamo from PyPI - install_dynamo_wheels(gpu_type) + # Install dynamo from PyPI (only needed when not using sglang router) + install_dynamo_wheels(gpu_type) # Run custom setup script if provided _run_setup_script(setup_script) - # Start frontend AFTER installing dynamo (traditional mode only) - if need_frontend: + # Start frontend AFTER installing dynamo (traditional mode only, not when using sglang router) + if need_frontend and not use_sglang_router: logging.info("Starting frontend in traditional mode (after dynamo installation)") # Open log files for frontend @@ -135,6 +135,7 @@ def setup_prefill_worker( rank=local_rank, profiler=profiler, dump_config_path=dump_config_path, + use_sglang_router=use_sglang_router, ) return run_command(cmd_to_run) @@ -159,8 +160,8 @@ def setup_decode_worker( if not wait_for_etcd(f"http://{master_ip}:{ETCD_CLIENT_PORT}"): raise RuntimeError("Failed to connect to etcd") - # Install dynamo from PyPI - install_dynamo_wheels(gpu_type) + # Install dynamo from PyPI (only 
needed when not using sglang router) + install_dynamo_wheels(gpu_type) # Run custom setup script if provided _run_setup_script(setup_script) @@ -179,10 +180,57 @@ def setup_decode_worker( rank=local_rank, profiler=profiler, dump_config_path=dump_config_path, + use_sglang_router=use_sglang_router, ) return run_command(cmd_to_run) +def setup_router_worker( + router_idx: int, + prefill_ips: list[str], + decode_ips: list[str], + host: str = "0.0.0.0", + port: int = 8000, + server_port: int = 30000, + bootstrap_port: int = 30001, +) -> int: + """Setup an sglang router worker for PD disaggregation. + + Args: + router_idx: Index of this router instance (for logging) + prefill_ips: List of prefill worker leader IPs + decode_ips: List of decode worker leader IPs + host: Host to bind the router to + port: Port to bind the router to + server_port: Port where prefill/decode servers listen (default: 30000) + bootstrap_port: Disaggregation bootstrap port for prefill servers (default: 30001) + + Returns: + Exit code from the router process + """ + logging.info(f"Setting up sglang router {router_idx}") + logging.info(f" Prefill IPs: {prefill_ips}") + logging.info(f" Decode IPs: {decode_ips}") + logging.info(f" Server port: {server_port}, Bootstrap port: {bootstrap_port}") + + # Build router command + router_args = ["python", "-m", "sglang_router.launch_router", "--pd-disaggregation"] + + # Prefill servers need: --prefill http://IP:server_port bootstrap_port + for ip in prefill_ips: + router_args.extend(["--prefill", f"http://{ip}:{server_port}", str(bootstrap_port)]) + + # Decode servers just need: --decode http://IP:server_port + for ip in decode_ips: + router_args.extend(["--decode", f"http://{ip}:{server_port}"]) + + router_args.extend(["--host", host, "--port", str(port)]) + + cmd = " ".join(router_args) + logging.info(f"Router command: {cmd}") + return run_command(cmd) + + def setup_aggregated_worker( worker_idx: int, local_rank: int, @@ -211,14 +259,14 @@ def 
setup_aggregated_worker( if not wait_for_etcd(f"http://{master_ip}:{ETCD_CLIENT_PORT}"): raise RuntimeError("Failed to connect to etcd") - # Install dynamo from PyPI - install_dynamo_wheels(gpu_type) + # Install dynamo from PyPI (only needed when not using sglang router) + install_dynamo_wheels(gpu_type) # Run custom setup script if provided _run_setup_script(setup_script) - # Start frontend AFTER installing dynamo (traditional mode only) - if need_frontend: + # Start frontend AFTER installing dynamo (traditional mode only, not when using sglang router) + if need_frontend and not use_sglang_router: logging.info("Starting frontend in traditional mode (after dynamo installation)") # Open log files for frontend @@ -246,5 +294,6 @@ def setup_aggregated_worker( rank=local_rank, profiler=profiler, dump_config_path=dump_config_path, + use_sglang_router=use_sglang_router, ) return run_command(cmd_to_run) diff --git a/src/srtctl/__init__.py b/src/srtctl/__init__.py index 935a4f53..9a6ec207 100644 --- a/src/srtctl/__init__.py +++ b/src/srtctl/__init__.py @@ -5,12 +5,10 @@ __version__ = "0.1.0" from .core.config import load_config, get_srtslurm_setting -from .backends.base import Backend -from .backends.sglang import SGLangBackend +from .core.backend import SGLangBackend __all__ = [ "load_config", "get_srtslurm_setting", - "Backend", "SGLangBackend", ] diff --git a/src/srtctl/backends/__init__.py b/src/srtctl/backends/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/srtctl/backends/base.py b/src/srtctl/backends/base.py deleted file mode 100644 index 0a273ee2..00000000 --- a/src/srtctl/backends/base.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Base backend interface for inference frameworks. - -Defines a protocol for framework-specific implementations. 
-""" - -from abc import ABC, abstractmethod -from pathlib import Path - - -class Backend(ABC): - """Base class for inference backend implementations. - - Each backend is responsible for: - 1. Generating backend-specific config files - 2. Rendering commands with proper flags and environment variables - 3. Generating SLURM job scripts from Jinja templates - """ - - def __init__(self, config: dict): - """Initialize backend with user config. - - Args: - config: Full user configuration dict - """ - self.config = config - self.backend_config = config.get("backend", {}) - self.resources = config.get("resources", {}) - self.model = config.get("model", {}) - self.slurm = config.get("slurm", {}) - - @abstractmethod - def generate_config_file(self, params: dict = None) -> Path | None: - """Generate backend-specific config file. - - Args: - params: Optional sweep parameters for template expansion - - Returns: - Path to generated config file, or None if not applicable - """ - pass - - @abstractmethod - def render_command(self, mode: str, config_path: Path = None) -> str: - """Render full command that would be executed. - - Args: - mode: Worker mode (e.g., "prefill", "decode", "aggregated") - config_path: Path to generated config file (if applicable) - - Returns: - Multi-line bash command string with env vars and flags - """ - pass - - @abstractmethod - def generate_slurm_script(self, config_path: Path = None, timestamp: str = None) -> tuple[Path, str]: - """Generate SLURM job script from Jinja template. - - Args: - config_path: Path to backend-specific config file (if applicable) - timestamp: Timestamp for job submission (used in log directory naming) - - Returns: - Tuple of (script_path, rendered_script_content) - """ - pass - - def get_environment_vars(self, mode: str) -> dict[str, str]: - """Get environment variables for this mode. 
- - Args: - mode: Worker mode - - Returns: - Dict of environment variable key-value pairs - """ - env_key = f"{mode}_environment" - return self.backend_config.get(env_key, {}) - - def is_disaggregated(self) -> bool: - """Check if running in disaggregated mode (has prefill/decode nodes).""" - return self.resources.get("prefill_nodes") is not None diff --git a/src/srtctl/backends/sglang.py b/src/srtctl/backends/sglang.py deleted file mode 100644 index 5538c439..00000000 --- a/src/srtctl/backends/sglang.py +++ /dev/null @@ -1,400 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -SGLang backend support. -""" - -import logging -import os -import tempfile -import yaml -from datetime import datetime -from jinja2 import Template -from pathlib import Path - -import srtctl -from srtctl.core.config import get_srtslurm_setting -from srtctl.core.sweep import expand_template - -from .base import Backend - - -class SGLangBackend(Backend): - """SGLang backend for distributed serving.""" - - def __init__(self, config: dict, setup_script: str = None): - """Initialize SGLang backend. - - Args: - config: Full user configuration dict - setup_script: Optional custom setup script name in configs directory - """ - super().__init__(config) - self.setup_script = setup_script - - def generate_config_file(self, params: dict = None) -> Path | None: - """Generate SGLang YAML config file. 
- - Args: - params: Optional sweep parameters for template expansion - - Returns: - Path to generated config file - """ - if "sglang_config" not in self.backend_config: - return None - - sglang_cfg = self.backend_config["sglang_config"] - - # Expand templates if sweeping - if params: - sglang_cfg = expand_template(sglang_cfg, params) - logging.info(f"Expanded config with params: {params}") - - # Validate that all keys use dashes, not underscores - for mode in ["prefill", "decode", "aggregated"]: - if mode in sglang_cfg and sglang_cfg[mode]: - for key in sglang_cfg[mode].keys(): - if "_" in key: - raise ValueError( - f"Invalid key '{key}' in sglang_config.{mode}: " - f"Keys must use dashes (kebab-case), not underscores. " - f"Use '{key.replace('_', '-')}' instead." - ) - - # Extract prefill, decode, and aggregated configs (no conversion needed - already using dashes) - result = {} - for mode in ["prefill", "decode", "aggregated"]: - if mode in sglang_cfg: - result[mode] = sglang_cfg[mode] - - # Add environment variables as top-level keys - for mode in ["prefill", "decode", "aggregated"]: - env_vars = self.get_environment_vars(mode) - if env_vars: - result[f"{mode}_environment"] = env_vars - - # Write to temp file - fd, temp_path = tempfile.mkstemp(suffix=".yaml", prefix="sglang_config_") - with os.fdopen(fd, "w") as f: - yaml.dump(result, f, default_flow_style=False) - - logging.info(f"Generated SGLang config: {temp_path}") - return Path(temp_path) - - def render_command(self, mode: str, config_path: Path = None) -> str: - """Render full SGLang command with all flags inlined. 
- - Args: - mode: "prefill" or "decode" - config_path: Path to generated SGLang config file - - Returns: - Multi-line bash command string - """ - lines = [] - - # Environment variables - env_vars = self.get_environment_vars(mode) or {} - for key, val in env_vars.items(): - lines.append(f"{key}={val} \\") - - # Python command - use sglang.launch_server when profiler != none, dynamo.sglang otherwise - profiling_type = (self.config.get("profiling") or {}).get("type") or "none" - nsys_prefix = "nsys profile -t cuda,nvtx --cuda-graph-trace=node -c cudaProfilerApi --capture-range-end stop --force-overwrite true" - if profiling_type == "nsys": - lines.append(f"{nsys_prefix} python3 -m sglang.launch_server \\") - elif profiling_type == "torch": - lines.append("python3 -m sglang.launch_server \\") - else: - lines.append("python3 -m dynamo.sglang \\") - - # Inline all SGLang flags from config file - if config_path: - with open(config_path) as f: - sglang_config = yaml.load(f, Loader=yaml.FullLoader) - - mode_config = sglang_config.get(mode, {}) - flag_lines = self._config_to_flags(mode_config) - lines.extend(flag_lines) - - # Add coordination flags - coord_flags = self._get_coordination_flags(mode) - lines.extend(coord_flags) - - return "\n".join(lines) - - def _config_to_flags(self, config: dict) -> list[str]: - """Convert config dict to CLI flags. 
- - Args: - config: SGLang config dict for this mode - - Returns: - List of flag strings with backslash continuations - """ - lines = [] - profiling_type = (self.config.get("profiling") or {}).get("type") or "none" - - for key, value in sorted(config.items()): - # Convert underscores to hyphens - flag_name = key.replace("_", "-") - - # Always pass disaggregation-mode so profiling runs in PD mode - - if isinstance(value, bool): - if value: - lines.append(f" --{flag_name} \\") - elif isinstance(value, list): - values_str = " ".join(str(v) for v in value) - lines.append(f" --{flag_name} {values_str} \\") - else: - lines.append(f" --{flag_name} {value} \\") - - return lines - - def _get_coordination_flags(self, mode: str) -> list[str]: - """Get multi-node coordination flags. - - Args: - mode: "prefill" or "decode" - - Returns: - List of coordination flag strings - """ - lines = [] - - # Determine nnodes based on mode - if self.is_disaggregated(): - nnodes = self.resources["prefill_nodes"] if mode == "prefill" else self.resources["decode_nodes"] - else: - nnodes = self.resources["agg_nodes"] - - # Coordination flags - lines.append(" --dist-init-addr $HOST_IP_MACHINE:$PORT \\") - lines.append(f" --nnodes {nnodes} \\") - lines.append(" --node-rank $RANK \\") - - return lines - - def _get_enable_config_dump(self) -> bool: - """Get enable_config_dump value, handling profiling mode. 
- - Returns: - True if config dump should be enabled, False otherwise - """ - # Get value from config (defaults to True in schema) - enable_config_dump = self.config.get("enable_config_dump", True) - - # Auto-disable when profiling is enabled (unless explicitly set to True) - profiling_type = (self.config.get("profiling") or {}).get("type") or "none" - if profiling_type != "none": - # When profiling, disable config dump by default - # User can explicitly set enable_config_dump: true to override - return False - - return enable_config_dump - - def generate_slurm_script(self, config_path: Path = None, timestamp: str = None) -> tuple[Path, str]: - """Generate SLURM job script from Jinja template. - - Args: - config_path: Path to SGLang config file - timestamp: Timestamp for job submission - - Returns: - Tuple of (script_path, rendered_script_content) - """ - if timestamp is None: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - # Determine mode and node counts - is_aggregated = not self.is_disaggregated() - - if is_aggregated: - agg_nodes = self.resources["agg_nodes"] - agg_workers = self.resources["agg_workers"] - prefill_nodes = 0 - decode_nodes = 0 - prefill_workers = 0 - decode_workers = 0 - total_nodes = agg_nodes - else: - prefill_nodes = self.resources["prefill_nodes"] - decode_nodes = self.resources["decode_nodes"] - prefill_workers = self.resources["prefill_workers"] - decode_workers = self.resources["decode_workers"] - agg_nodes = 0 - agg_workers = 0 - total_nodes = prefill_nodes + decode_nodes - - # Get SLURM settings - job_name = self.config.get("name", "srtctl-job") - account = self.slurm.get("account") or get_srtslurm_setting("default_account") - partition = self.slurm.get("partition") or get_srtslurm_setting("default_partition") - time_limit = self.slurm.get("time_limit") or get_srtslurm_setting("default_time_limit", "04:00:00") - - # Get resource settings from srtslurm.yaml if available - gpus_per_node = get_srtslurm_setting("gpus_per_node", 
self.resources.get("gpus_per_node")) - network_interface = get_srtslurm_setting("network_interface", None) - - # Get backend settings - gpu_type = self.backend_config.get("gpu_type", "h100") - - # Benchmark config - benchmark_config = self.config.get("benchmark", {}) - bench_type = benchmark_config.get("type", "manual") - do_benchmark = bench_type != "manual" - - # Parse benchmark args if applicable - parsable_config = "" - if bench_type == "sa-bench": - isl = benchmark_config.get("isl") - osl = benchmark_config.get("osl") - concurrencies = benchmark_config.get("concurrencies") - req_rate = benchmark_config.get("req_rate", "inf") - - if isinstance(concurrencies, list): - concurrency_str = "x".join(str(c) for c in concurrencies) - else: - concurrency_str = str(concurrencies) - - parsable_config = f"{isl} {osl} {concurrency_str} {req_rate}" - elif bench_type == "mmlu": - num_examples = benchmark_config.get("num_examples", 200) - max_tokens = benchmark_config.get("max_tokens", 2048) - repeat = benchmark_config.get("repeat", 8) - num_threads = benchmark_config.get("num_threads", 512) - parsable_config = f"{num_examples} {max_tokens} {repeat} {num_threads}" - elif bench_type == "gpqa": - num_examples = benchmark_config.get("num_examples", 198) - max_tokens = benchmark_config.get("max_tokens", 32768) - repeat = benchmark_config.get("repeat", 8) - num_threads = benchmark_config.get("num_threads", 128) - parsable_config = f"{num_examples} {max_tokens} {repeat} {num_threads}" - elif bench_type == "longbenchv2": - num_examples = benchmark_config.get("num_examples", None) - max_tokens = benchmark_config.get("max_tokens", 16384) - max_context_length = benchmark_config.get("max_context_length", 128000) - num_threads = benchmark_config.get("num_threads", 16) - categories = benchmark_config.get("categories", None) - parsable_config = f"{num_examples} {max_tokens} {max_context_length} {num_threads} {categories}" - - # Config directory should point to where deepep_config.json lives 
- # This is typically the configs/ directory in the yaml-config repo - yaml_config_root = Path(srtctl.__file__).parent.parent.parent - - # Log directory - check srtslurm.yaml first, then fall back to default - srtctl_root_setting = get_srtslurm_setting("srtctl_root") - if srtctl_root_setting: - srtctl_root = Path(srtctl_root_setting) - else: - # Fall back to default: current yaml-config directory (which contains scripts/) - srtctl_root = yaml_config_root - - # Use srtctl_root for config_dir_path so it respects srtslurm.yaml setting - config_dir_path = srtctl_root / "configs" - log_dir_path = srtctl_root / "logs" - - # Build profiling env injections - profiling_cfg = self.config.get("profiling") or {} - - def build_env_str(cfg: dict) -> str: - parts: list[str] = [] - if "isl" in cfg and cfg["isl"] is not None: - parts.append(f"PROFILE_ISL={cfg['isl']}") - if "osl" in cfg and cfg["osl"] is not None: - parts.append(f"PROFILE_OSL={cfg['osl']}") - if "concurrency" in cfg and cfg["concurrency"] is not None: - parts.append(f"PROFILE_CONCURRENCY={cfg['concurrency']}") - if "start_step" in cfg and cfg["start_step"] is not None: - parts.append(f"PROFILE_START_STEP={cfg['start_step']}") - if "stop_step" in cfg and cfg["stop_step"] is not None: - parts.append(f"PROFILE_STOP_STEP={cfg['stop_step']}") - return " ".join(parts) - - # Use the same profiling spec for both prefill and decode; in PD - # disaggregation mode this single spec drives both sides. 
- prefill_profile_env = build_env_str(profiling_cfg) - decode_profile_env = build_env_str(profiling_cfg) - - profiler_mode = profiling_cfg.get("type") or "none" - # Template variables - template_vars = { - "job_name": job_name, - "total_nodes": total_nodes, - "account": account, - "time_limit": time_limit, - "prefill_nodes": prefill_nodes, - "decode_nodes": decode_nodes, - "prefill_workers": prefill_workers, - "decode_workers": decode_workers, - "agg_nodes": agg_nodes, - "agg_workers": agg_workers, - "is_aggregated": is_aggregated, - "model_dir": self.model.get("path"), - "config_dir": str(config_dir_path), - "container_image": self.model.get("container"), - "gpus_per_node": gpus_per_node, - "network_interface": network_interface, - "gpu_type": gpu_type, - "partition": partition, - "enable_multiple_frontends": self.backend_config.get("enable_multiple_frontends", True), - "num_additional_frontends": self.backend_config.get("num_additional_frontends", 9), - "use_sglang_router": self.backend_config.get("use_sglang_router", False), - "do_benchmark": do_benchmark, - "benchmark_type": bench_type, - "benchmark_arg": parsable_config, - "timestamp": timestamp, - # Config dump enabled by default (True in schema) - # Auto-disabled when profiling unless explicitly enabled - "enable_config_dump": self._get_enable_config_dump(), - "log_dir_prefix": str(log_dir_path), # Absolute path to logs directory - "profiler": profiler_mode, - "prefill_profile_env": prefill_profile_env, - "decode_profile_env": decode_profile_env, - "setup_script": self.setup_script, - "use_gpus_per_node_directive": get_srtslurm_setting("use_gpus_per_node_directive", True), - "use_segment_sbatch_directive": get_srtslurm_setting("use_segment_sbatch_directive", True), - "extra_container_mounts": ",".join(self.config.get("extra_mount") or []), - } - - # Select template based on mode - if is_aggregated: - template_name = "job_script_template_agg.j2" - else: - template_name = "job_script_template_disagg.j2" - - # 
Find template path - check srtslurm.yaml first, then fall back to default location - srtctl_root = get_srtslurm_setting("srtctl_root") - - if srtctl_root: - # User specified srtctl_root in srtslurm.yaml - template_path = Path(srtctl_root) / "scripts" / "templates" / template_name - else: - # Fall back to default: current yaml-config directory (which contains scripts/) - yaml_config_root = Path(srtctl.__file__).parent.parent.parent - template_path = yaml_config_root / "scripts" / "templates" / template_name - - if not template_path.exists(): - raise FileNotFoundError( - f"Template not found: {template_path}\n" - f"Set 'srtctl_root' in srtslurm.yaml to point to your srtctl repo.\n" - f"Example: srtctl_root: /mnt/lustre01/users/slurm-shared/ishan/sweepr" - ) - - # Render template - with open(template_path) as f: - template = Template(f.read()) - - rendered_script = template.render(**template_vars) - - # Write to temporary file - fd, temp_path = tempfile.mkstemp(suffix=".sh", prefix="slurm_job_") - with os.fdopen(fd, "w") as f: - f.write(rendered_script) - - logging.info(f"Generated SLURM job script: {temp_path}") - return Path(temp_path), rendered_script diff --git a/src/srtctl/cli/submit.py b/src/srtctl/cli/submit.py index cb35c609..939905aa 100644 --- a/src/srtctl/cli/submit.py +++ b/src/srtctl/cli/submit.py @@ -26,7 +26,7 @@ # Import from srtctl modules from srtctl.core.config import load_config from srtctl.core.sweep import generate_sweep_configs -from srtctl.backends.sglang import SGLangBackend +from srtctl.core.backend import SGLangBackend def setup_logging(level: int = logging.INFO) -> None: @@ -38,113 +38,50 @@ def setup_logging(level: int = logging.INFO) -> None: def render_commands_file(backend, sglang_config_path: Path, output_path: Path) -> Path: - """Generate commands.sh file with rendered SGLang commands. 
+ """Generate commands.sh with rendered SGLang commands.""" + content = f"""#!/bin/bash +# Generated SGLang commands - Config: {sglang_config_path} - Args: - backend: SGLang backend instance - sglang_config_path: Path to sglang_config.yaml - output_path: Where to save commands.sh +# PREFILL +{backend.render_command(mode="prefill", config_path=sglang_config_path)} - Returns: - Path to generated commands.sh - """ - content = "#!/bin/bash\n" - content += "# Generated SGLang commands\n" - content += f"# Config: {sglang_config_path}\n\n" - content += "# ============================================================\n" - content += "# PREFILL WORKER COMMAND\n" - content += "# ============================================================\n\n" - content += backend.render_command(mode="prefill", config_path=sglang_config_path) - content += "\n\n" - content += "# ============================================================\n" - content += "# DECODE WORKER COMMAND\n" - content += "# ============================================================\n\n" - content += backend.render_command(mode="decode", config_path=sglang_config_path) - content += "\n" - - with open(output_path, "w") as f: - f.write(content) +# DECODE +{backend.render_command(mode="decode", config_path=sglang_config_path)} +""" + output_path.write_text(content) output_path.chmod(0o755) - return output_path -class DryRunContext: - """Context for dry-run mode - creates output directory and saves artifacts""" - - def __init__(self, config: dict, job_name: str = None): - self.config = config - self.job_name = job_name or config.get("name", "dry-run") - self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - self.output_dir = None - self.sglang_config_path = None - - def setup(self) -> Path: - """Create dry-run output directory""" - # Create in dry-runs/ - base_dir = Path.cwd() / "dry-runs" - self.output_dir = base_dir / f"{self.job_name}_{self.timestamp}" - self.output_dir.mkdir(parents=True, exist_ok=True) - - 
def run_dry_run(config: dict, backend, sglang_config_path: Path | None = None) -> Path:
    """Execute dry-run: save resolved artifacts to dry-runs/ and print a summary.

    Args:
        config: Resolved job config (defaults already applied).
        backend: Backend instance used to render worker commands; may be None
            when the config has no recognized backend type.
        sglang_config_path: Optional generated sglang_config.yaml; when present
            it is copied into the output dir and commands.sh is rendered from it.

    Returns:
        The created output directory (dry-runs/<job-name>_<timestamp>).
    """
    job_name = config.get("name", "dry-run")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Artifacts live under the caller's cwd so repeated dry-runs accumulate
    # side by side; the timestamp suffix keeps directories unique.
    output_dir = Path.cwd() / "dry-runs" / f"{job_name}_{timestamp}"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save config (resolved, with all defaults applied)
    with open(output_dir / "config.yaml", "w") as f:
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)

    # Save sglang config if present; commands.sh is only rendered when a
    # backend config exists, since rendering needs the flag definitions.
    has_sglang = False
    if sglang_config_path and sglang_config_path.exists():
        shutil.copy(sglang_config_path, output_dir / "sglang_config.yaml")
        render_commands_file(backend, sglang_config_path, output_dir / "commands.sh")
        has_sglang = True

    # Save metadata
    with open(output_dir / "metadata.json", "w") as f:
        json.dump({"job_name": job_name, "timestamp": timestamp, "mode": "dry-run"}, f, indent=2)

    # Print summary
    print(f"\n{'=' * 60}\nšŸ” DRY-RUN: {job_name}\n{'=' * 60}")
    print(f"Output: {output_dir}")
    print(f"Files: config.yaml{', sglang_config.yaml, commands.sh' if has_sglang else ''}, metadata.json")
    print(f"{'=' * 60}\n")

    return output_dir
ctx.save_rendered_commands(backend, sglang_config_path) - else: - sglang_config_path = None - - # Save metadata - ctx.save_metadata(config) - - # Print summary - ctx.print_summary() - + backend = SGLangBackend(config, setup_script=setup_script) if backend_type == "sglang" else None + sglang_config_path = backend.generate_config_file() if backend else None + run_dry_run(config, backend, sglang_config_path) return # Real submission mode @@ -259,73 +175,49 @@ def submit_single( shutil.copy(sglang_config_path, log_dir / "sglang_config.yaml") # Generate jobid.json metadata - - resources = config.get("resources", {}) - backend_cfg = config.get("backend", {}) - model = config.get("model", {}) - slurm_cfg = config.get("slurm", {}) + resources, model, slurm_cfg = config.get("resources", {}), config.get("model", {}), config.get("slurm", {}) benchmark_cfg = config.get("benchmark", {}) - metadata = { - "version": "1.0", - "generated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "run_metadata": { - "slurm_job_id": job_id, - "run_date": timestamp, - "job_name": config.get("name", "unnamed"), - "account": slurm_cfg.get("account"), - "partition": slurm_cfg.get("partition"), - "time_limit": slurm_cfg.get("time_limit"), - "container": model.get("container"), - "model_dir": model.get("path"), - "gpus_per_node": resources.get("gpus_per_node"), - "gpu_type": backend_cfg.get("gpu_type"), - "mode": "aggregated" if is_aggregated else "disaggregated", - }, + run_meta = { + "slurm_job_id": job_id, + "run_date": timestamp, + "job_name": config.get("name", "unnamed"), + "account": slurm_cfg.get("account"), + "partition": slurm_cfg.get("partition"), + "time_limit": slurm_cfg.get("time_limit"), + "container": model.get("container"), + "model_dir": model.get("path"), + "gpus_per_node": resources.get("gpus_per_node"), + "gpu_type": config.get("backend", {}).get("gpu_type"), + "mode": "aggregated" if is_aggregated else "disaggregated", } - - # Add mode-specific metadata if is_aggregated: - 
metadata["run_metadata"].update( - { - "agg_nodes": resources.get("agg_nodes"), - "agg_workers": resources.get("agg_workers"), - } - ) + run_meta.update(agg_nodes=resources.get("agg_nodes"), agg_workers=resources.get("agg_workers")) else: - metadata["run_metadata"].update( - { - "prefill_nodes": resources.get("prefill_nodes"), - "decode_nodes": resources.get("decode_nodes"), - "prefill_workers": resources.get("prefill_workers"), - "decode_workers": resources.get("decode_workers"), - } + run_meta.update( + prefill_nodes=resources.get("prefill_nodes"), + decode_nodes=resources.get("decode_nodes"), + prefill_workers=resources.get("prefill_workers"), + decode_workers=resources.get("decode_workers"), ) - # Add benchmark metadata if present - if benchmark_cfg: - bench_type = benchmark_cfg.get("type", "manual") - profiler_metadata = {"type": bench_type} + metadata = { + "version": "1.0", + "generated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "run_metadata": run_meta, + } + if bench_type := benchmark_cfg.get("type", "manual"): + bench_meta = {"type": bench_type} if bench_type == "sa-bench": - concurrencies = benchmark_cfg.get("concurrencies", []) - # Handle both list and string formats - if isinstance(concurrencies, list): - concurrency_str = "x".join(str(c) for c in concurrencies) if concurrencies else "" - else: - concurrency_str = str(concurrencies) if concurrencies else "" - profiler_metadata.update( - { - "isl": str(benchmark_cfg.get("isl", "")), - "osl": str(benchmark_cfg.get("osl", "")), - "concurrencies": concurrency_str, - "req-rate": str(benchmark_cfg.get("req_rate", "inf")), - } + conc = benchmark_cfg.get("concurrencies", []) + bench_meta.update( + isl=str(benchmark_cfg.get("isl", "")), + osl=str(benchmark_cfg.get("osl", "")), + concurrencies="x".join(str(c) for c in conc) if isinstance(conc, list) else str(conc or ""), + **{"req-rate": str(benchmark_cfg.get("req_rate", "inf"))}, ) - - metadata["profiler_metadata"] = profiler_metadata - - # Add tags 
def main():
    """CLI entry point: parse arguments and dispatch to single or sweep submission."""
    setup_logging()

    parser = argparse.ArgumentParser(
        description="srtctl - SLURM job submission",
        epilog="Examples:\n  srtctl apply -f config.yaml\n  srtctl dry-run -f sweep.yaml --sweep",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # Args shared by both subcommands
    def add_common_args(p):
        p.add_argument("-f", "--file", type=Path, required=True, dest="config", help="YAML config file")
        p.add_argument("--sweep", action="store_true", help="Force sweep mode")

    apply_parser = subparsers.add_parser("apply", help="Submit job(s) to SLURM")
    add_common_args(apply_parser)
    apply_parser.add_argument("--setup-script", type=str, help="Custom setup script in configs/")
    apply_parser.add_argument("--tags", type=str, help="Comma-separated tags")

    dry_run_parser = subparsers.add_parser("dry-run", help="Validate without submitting")
    add_common_args(dry_run_parser)

    args = parser.parse_args()
    if not args.config.exists():
        logging.error(f"Config not found: {args.config}")
        sys.exit(1)

    is_dry_run = args.command == "dry-run"

    # Auto-detect sweep mode unless forced. A detection failure (e.g. a config
    # that is hard to pre-parse) must not abort the run: warn and fall back to
    # single-job mode, letting the real submission path report any config error.
    is_sweep = args.sweep
    if not is_sweep:
        try:
            is_sweep = is_sweep_config(args.config)
        except Exception as e:
            logging.warning(f"Could not auto-detect sweep mode: {e}")

    # Normalize "--tags a, b," into ["a", "b"]; None when nothing usable remains.
    tags = [t.strip() for t in (getattr(args, "tags", "") or "").split(",") if t.strip()] or None

    try:
        # --setup-script only exists on the `apply` subparser, hence getattr.
        setup_script = getattr(args, "setup_script", None)
        if is_sweep:
            submit_sweep(args.config, dry_run=is_dry_run, setup_script=setup_script, tags=tags)
        else:
            submit_single(config_path=args.config, dry_run=is_dry_run, setup_script=setup_script, tags=tags)
    except Exception as e:
        logging.exception(f"Error: {e}")
        sys.exit(1)
b/src/srtctl/core/backend.py new file mode 100644 index 00000000..3419ed41 --- /dev/null +++ b/src/srtctl/core/backend.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""SGLang backend for SLURM job generation.""" + +import logging +import os +import tempfile +import yaml +from datetime import datetime +from jinja2 import Template +from pathlib import Path + +import srtctl +from srtctl.core.config import get_srtslurm_setting +from srtctl.core.sweep import expand_template + + +class SGLangBackend: + """SGLang backend for distributed serving.""" + + def __init__(self, config: dict, setup_script: str = None): + self.config = config + self.backend_config = config.get("backend", {}) + self.resources = config.get("resources", {}) + self.model = config.get("model", {}) + self.slurm = config.get("slurm", {}) + self.setup_script = setup_script + + def is_disaggregated(self) -> bool: + return self.resources.get("prefill_nodes") is not None + + def get_environment_vars(self, mode: str) -> dict[str, str]: + return self.backend_config.get(f"{mode}_environment", {}) + + def _profiling_type(self) -> str: + return (self.config.get("profiling") or {}).get("type") or "none" + + def _config_to_flags(self, config: dict) -> list[str]: + lines = [] + for key, value in sorted(config.items()): + flag = key.replace("_", "-") + if isinstance(value, bool): + if value: + lines.append(f" --{flag} \\") + elif isinstance(value, list): + lines.append(f" --{flag} {' '.join(str(v) for v in value)} \\") + else: + lines.append(f" --{flag} {value} \\") + return lines + + def generate_config_file(self, params: dict = None) -> Path | None: + """Generate SGLang YAML config file.""" + if "sglang_config" not in self.backend_config: + return None + + sglang_cfg = self.backend_config["sglang_config"] + if params: + sglang_cfg = expand_template(sglang_cfg, params) + 
logging.info(f"Expanded config with params: {params}") + + # Validate kebab-case keys + for mode in ["prefill", "decode", "aggregated"]: + if mode in sglang_cfg and sglang_cfg[mode]: + for key in sglang_cfg[mode].keys(): + if "_" in key: + raise ValueError(f"Invalid key '{key}': use '{key.replace('_', '-')}' (kebab-case)") + + result = {mode: sglang_cfg[mode] for mode in ["prefill", "decode", "aggregated"] if mode in sglang_cfg} + for mode in ["prefill", "decode", "aggregated"]: + if env := self.get_environment_vars(mode): + result[f"{mode}_environment"] = env + + fd, temp_path = tempfile.mkstemp(suffix=".yaml", prefix="sglang_config_") + with os.fdopen(fd, "w") as f: + yaml.dump(result, f, default_flow_style=False) + logging.info(f"Generated SGLang config: {temp_path}") + return Path(temp_path) + + def render_command(self, mode: str, config_path: Path = None) -> str: + """Render full SGLang command with all flags inlined.""" + lines = [f"{k}={v} \\" for k, v in (self.get_environment_vars(mode) or {}).items()] + + prof = self._profiling_type() + use_sglang = prof != "none" or self.backend_config.get("use_sglang_router", False) + if prof == "nsys": + lines.append( + "nsys profile -t cuda,nvtx --cuda-graph-trace=node -c cudaProfilerApi --capture-range-end stop --force-overwrite true python3 -m sglang.launch_server \\" + ) + elif use_sglang: + lines.append("python3 -m sglang.launch_server \\") + else: + lines.append("python3 -m dynamo.sglang \\") + + if config_path: + with open(config_path) as f: + sglang_config = yaml.safe_load(f) + lines.extend(self._config_to_flags(sglang_config.get(mode, {}))) + + nnodes = ( + (self.resources["prefill_nodes"] if mode == "prefill" else self.resources["decode_nodes"]) + if self.is_disaggregated() + else self.resources["agg_nodes"] + ) + lines.extend( + [ + " --dist-init-addr $HOST_IP_MACHINE:$PORT \\", + f" --nnodes {nnodes} \\", + " --node-rank $RANK \\", + ] + ) + return "\n".join(lines) + + def generate_slurm_script(self, 
config_path: Path = None, timestamp: str = None) -> tuple[Path, str]: + """Generate SLURM job script from Jinja template.""" + timestamp = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S") + is_aggregated = not self.is_disaggregated() + + if is_aggregated: + agg_nodes, agg_workers = self.resources["agg_nodes"], self.resources["agg_workers"] + prefill_nodes = decode_nodes = prefill_workers = decode_workers = 0 + total_nodes = agg_nodes + else: + prefill_nodes, decode_nodes = self.resources["prefill_nodes"], self.resources["decode_nodes"] + prefill_workers, decode_workers = self.resources["prefill_workers"], self.resources["decode_workers"] + agg_nodes = agg_workers = 0 + total_nodes = prefill_nodes + decode_nodes + + # SLURM settings + job_name = self.config.get("name", "srtctl-job") + account = self.slurm.get("account") or get_srtslurm_setting("default_account") + partition = self.slurm.get("partition") or get_srtslurm_setting("default_partition") + time_limit = self.slurm.get("time_limit") or get_srtslurm_setting("default_time_limit", "04:00:00") + gpus_per_node = get_srtslurm_setting("gpus_per_node", self.resources.get("gpus_per_node")) + + # Benchmark config + benchmark_config = self.config.get("benchmark", {}) + bench_type = benchmark_config.get("type", "manual") + parsable_config = "" + if bench_type == "sa-bench": + conc = benchmark_config.get("concurrencies") + conc_str = "x".join(str(c) for c in conc) if isinstance(conc, list) else str(conc) + parsable_config = f"{benchmark_config.get('isl')} {benchmark_config.get('osl')} {conc_str} {benchmark_config.get('req_rate', 'inf')}" + + # Paths + srtctl_root = Path(get_srtslurm_setting("srtctl_root") or Path(srtctl.__file__).parent.parent.parent) + + # Profiling env + profiling_cfg = self.config.get("profiling") or {} + env_map = { + "isl": "PROFILE_ISL", + "osl": "PROFILE_OSL", + "concurrency": "PROFILE_CONCURRENCY", + "start_step": "PROFILE_START_STEP", + "stop_step": "PROFILE_STOP_STEP", + } + profile_env = 
" ".join( + f"{env}={profiling_cfg[k]}" for k, env in env_map.items() if profiling_cfg.get(k) is not None + ) + profiler_mode = self._profiling_type() + + template_vars = { + "job_name": job_name, + "total_nodes": total_nodes, + "account": account, + "time_limit": time_limit, + "prefill_nodes": prefill_nodes, + "decode_nodes": decode_nodes, + "prefill_workers": prefill_workers, + "decode_workers": decode_workers, + "agg_nodes": agg_nodes, + "agg_workers": agg_workers, + "is_aggregated": is_aggregated, + "model_dir": self.model.get("path"), + "config_dir": str(srtctl_root / "configs"), + "container_image": self.model.get("container"), + "gpus_per_node": gpus_per_node, + "network_interface": get_srtslurm_setting("network_interface"), + "gpu_type": self.backend_config.get("gpu_type", "h100"), + "partition": partition, + "enable_multiple_frontends": self.backend_config.get("enable_multiple_frontends", True), + "num_additional_frontends": self.backend_config.get("num_additional_frontends", 9), + "use_sglang_router": self.backend_config.get("use_sglang_router", False), + "do_benchmark": bench_type != "manual", + "benchmark_type": bench_type, + "benchmark_arg": parsable_config, + "timestamp": timestamp, + "enable_config_dump": profiler_mode == "none" and self.config.get("enable_config_dump", True), + "log_dir_prefix": str(srtctl_root / "logs"), + "profiler": profiler_mode, + "prefill_profile_env": profile_env, + "decode_profile_env": profile_env, + "setup_script": self.setup_script, + "use_gpus_per_node_directive": get_srtslurm_setting("use_gpus_per_node_directive", True), + "use_segment_sbatch_directive": get_srtslurm_setting("use_segment_sbatch_directive", True), + "extra_container_mounts": ",".join(self.config.get("extra_mount") or []), + } + + template_name = "job_script_template_agg.j2" if is_aggregated else "job_script_template_disagg.j2" + template_path = srtctl_root / "scripts" / "templates" / template_name + if not template_path.exists(): + raise 
FileNotFoundError(f"Template not found: {template_path}\nSet 'srtctl_root' in srtslurm.yaml") + + with open(template_path) as f: + rendered_script = Template(f.read()).render(**template_vars) + + fd, temp_path = tempfile.mkstemp(suffix=".sh", prefix="slurm_job_") + with os.fdopen(fd, "w") as f: + f.write(rendered_script) + logging.info(f"Generated SLURM job script: {temp_path}") + return Path(temp_path), rendered_script diff --git a/src/srtctl/core/schema.py b/src/srtctl/core/schema.py index 66b97cbe..34c4d078 100644 --- a/src/srtctl/core/schema.py +++ b/src/srtctl/core/schema.py @@ -248,6 +248,8 @@ class BackendConfig(BaseModel): # Frontend / router settings enable_multiple_frontends: bool = True + # Number of additional frontends/routers beyond the first (total = 1 + num_additional_frontends) + # Used for both dynamo frontends and sglang-router instances num_additional_frontends: int = 9 # Whether to launch sglang_router alongside the workers (PD disaggregation). # This is user-configurable via backend.use_sglang_router in the recipe.