58 changes: 42 additions & 16 deletions examples/disaggregated/slurm/benchmark/README.md
@@ -4,13 +4,15 @@ This directory contains scripts to run disaggregated inference benchmarks using

## Overview

The benchmarking process is orchestrated through a set of shell scripts and a Python script that work together:
The benchmarking process is orchestrated through a set of shell scripts and Python scripts that work together:

1. `submit.sh`: The main entry point for submitting benchmark jobs to SLURM. It runs a parameter sweep by calling `sbatch` with different configurations.
2. `disaggr_torch.slurm`: The SLURM script that sets up and runs a single benchmark experiment. It launches a container, generates a configuration file, starts the server and workers, and runs the benchmark client.
3. `gen_yaml.py`: A Python script that generates the `config.yaml` file needed by `trtllm-serve`. It determines the server and worker configuration based on SLURM environment variables and script arguments.
4. `start_worker.sh`: A shell script responsible for starting a `trtllm-serve disaggregated_mpi_worker` on each allocated machine.
5. `run_benchmark.sh`: A shell script that waits for the server to be healthy and then runs the actual benchmark client (`run_benchmark.py`, not included in this directory).
2. `disaggr_torch.slurm`: The SLURM script that sets up and runs a single benchmark experiment. It launches a container, generates configuration files, starts the server and workers, and runs the benchmark client.
3. `gen_worker_config.py`: A Python script that generates the worker configuration YAML file needed by `trtllm-serve`. It determines the worker configuration based on SLURM environment variables and script arguments.
4. `gen_server_config.py`: A Python script that generates the server configuration YAML file needed by `trtllm-serve`. It determines the server configuration based on the number of context and generation servers.
5. `start_worker.sh`: A shell script responsible for starting disaggregated workers using `trtllm-serve` on each allocated machine.
6. `start_server.sh`: A shell script responsible for starting the disaggregated server using `trtllm-serve`.
7. `run_benchmark.sh`: A shell script that waits for the server to be healthy and then runs the actual benchmark client (`run_benchmark.py`, not included in this directory).
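
A typical run boils down to submitting the sweep from this directory; the sketch below assumes `submit.sh` has already been edited for your cluster (paths are illustrative):

```bash
# Minimal sketch: submit the parameter sweep from the benchmark directory.
# submit.sh calls `sbatch disaggr_torch.slurm ...` with the configurations
# defined inside it; adjust those configurations before submitting.
cd examples/disaggregated/slurm/benchmark
./submit.sh
```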

## File Descriptions

@@ -58,28 +60,52 @@ It takes the following arguments in order:
24. `model_dir`: Model directory path.
25. `trtllm_repo`: TensorRT-LLM repository path.

### `gen_yaml.py`
### `gen_worker_config.py`

This Python script generates the `config.yaml` file that configures the `trtllm-serve` application. It reads SLURM environment variables (`SLURM_JOB_NODELIST`, `SLURM_TASKS_PER_NODE`) to distribute workers across nodes.
This Python script generates the worker configuration YAML file that configures the `trtllm-serve` workers. It creates separate configurations for context and generation workers with different tensor parallelism, batch sizes, and other parameters.

**Usage:**

The script is called from within `disaggr_torch.slurm`. It takes numerous arguments to define the model, parallelism, and server configurations.
The script is called from within `disaggr_torch.slurm`. It takes numerous arguments to define the model, parallelism, and worker configurations for both context and generation phases.
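
For reference, the call made from `disaggr_torch.slurm` has roughly the following shape; the numeric values below are illustrative placeholders, not defaults:

```bash
# Sketch of the gen_worker_config.py invocation (values are placeholders).
python3 gen_worker_config.py \
    --work_dir /path/to/logdir \
    --ctx_tp_size 4 --ctx_pp_size 1 \
    --ctx_batch_size 1 --ctx_max_num_tokens 8192 --ctx_max_seq_len 8192 \
    --ctx_free_gpu_memory_fraction 0.85 \
    --gen_tp_size 8 --gen_pp_size 1 \
    --gen_batch_size 256 --gen_max_num_tokens 256 --gen_max_seq_len 9216 \
    --gen_gpu_memory_fraction 0.85 \
    --eplb_num_slots 0 --mtp_size 0 \
    --cache_transceiver_max_num_tokens 8448 \
    --ctx_enable_attention_dp --gen_enable_attention_dp
```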

### `gen_server_config.py`

This Python script generates the server configuration YAML file that configures the `trtllm-serve` disaggregated server. It reads hostname information from the work directory and creates a configuration that specifies the URLs for context and generation servers.

**Usage:**

The script is called from within `start_server.sh`. It takes arguments for the number of context and generation servers and the work directory.
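
A minimal sketch of the call made from `start_server.sh` (server counts and path are illustrative; `--worker_port` and `--server_port` default to 8336 and 8333):

```bash
# Sketch of the gen_server_config.py invocation (values are placeholders).
python3 gen_server_config.py \
    --num_ctx_servers 1 \
    --num_gen_servers 1 \
    --work_dir /path/to/logdir
```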

### `start_worker.sh`

This script starts a `trtllm-serve disaggregated_mpi_worker`. It is launched by `srun` from the `disaggr_torch.slurm` script on all allocated nodes.

**Arguments:**

1. `config_file`: Path to the `config.yaml` file.
2. `enable_pdl`: `true` or `false`.
3. `ctx_gpus`: Number of GPUs used for the context phase.
4. `work_dir`: (Optional) Directory to store nsys profiling output.
1. `worker_type`: Either "CTX" or "GEN" to specify the worker type.
2. `worker_index`: Index of the worker instance.
3. `model_dir`: Path to the model directory.
4. `worker_port`: Port for the worker to listen on.
5. `benchmark_mode`: Benchmark mode, either `e2e` or `gen_only`.
6. `concurrency`: Concurrency level.
7. `enable_pdl`: `true` or `false`.
8. `work_dir`: Work directory for logs and configuration.
9. `nsys_on`: Whether to enable nsys profiling.
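
A sketch of how `disaggr_torch.slurm` launches one generation worker (normally wrapped in `srun`; paths and the concurrency value are illustrative):

```bash
# Sketch only: launch generation worker 0 on port 8336 in e2e mode,
# with PDL enabled and nsys profiling disabled (empty last argument).
bash start_worker.sh "GEN" 0 /path/to/model 8336 e2e 1024 true /path/to/logdir ""
```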

### `start_server.sh`

This script starts the `trtllm-serve disaggregated` server. It first generates the server configuration using `gen_server_config.py`, then starts the server process.

**Arguments:**

1. `num_ctx_servers`: Number of context servers.
2. `num_gen_servers`: Number of generation servers.
3. `work_dir`: Work directory for logs and configuration.
4. `script_dir`: Directory containing the scripts.
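
A sketch of the call made from `disaggr_torch.slurm` (normally wrapped in `srun`; paths are illustrative):

```bash
# Sketch only: one context server and one generation server.
bash start_server.sh 1 1 /path/to/logdir /path/to/benchmark/scripts
```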

### `run_benchmark.sh`

This script orchestrates the execution of the benchmark client. It waits for the `config.yaml` to be created and for the server's `/health` endpoint to respond, then it runs the benchmark.
This script orchestrates the execution of the benchmark client. It waits for the configuration files to be created and for the server's `/health` endpoint to respond, then it runs the benchmark.
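
A minimal sketch of what the readiness wait amounts to (the real script also waits for the configuration files; the port matches the `--server_port` default in `gen_server_config.py`, and the hostname is a placeholder):

```bash
# Sketch only: poll the server's /health endpoint until it responds.
# SERVER_HOSTNAME is a placeholder for the node running trtllm-serve.
until curl -sf "http://${SERVER_HOSTNAME}:8333/health" > /dev/null; do
    sleep 10
done
```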

**Arguments:**

@@ -97,9 +123,9 @@ This script orchestrates the execution of the benchmark client.
2. The user runs `./submit.sh`.
3. `submit.sh` submits one or more jobs to SLURM by calling `sbatch disaggr_torch.slurm` with different parameters.
4. For each job, SLURM allocates resources and runs `disaggr_torch.slurm`.
5. `disaggr_torch.slurm` runs `gen_yaml.py` to create a `config.yaml`.
6. `disaggr_torch.slurm` uses `srun` to launch `start_worker.sh` on all nodes, starting the MPI workers.
7. `disaggr_torch.slurm` starts the main `trtllm-serve` process.
5. `disaggr_torch.slurm` runs `gen_worker_config.py` to create worker configuration files.
6. `disaggr_torch.slurm` uses `srun` to launch `start_worker.sh` on all nodes, starting the MPI workers for both context and generation phases.
7. `disaggr_torch.slurm` starts the main `trtllm-serve` process using `start_server.sh`, which generates the server configuration using `gen_server_config.py`.
8. `disaggr_torch.slurm` runs `run_benchmark.sh` which waits for the server to be ready.
9. `run_benchmark.sh` executes the benchmark for each concurrency level specified.
10. After the benchmark, `run_benchmark.sh` and `disaggr_torch.slurm` attempt to kill the server and worker processes.
133 changes: 87 additions & 46 deletions examples/disaggregated/slurm/benchmark/disaggr_torch.slurm
@@ -7,6 +7,10 @@
#SBATCH --job-name=${job_name} # add your job name here or specify in the sbatch command
#SBATCH --time=02:00:00

set -u
set -e
set -x

# Context servers arguments
num_ctx_servers=${1}
ctx_tp_size=${2}
@@ -42,7 +46,10 @@ mounts=${23}
workdir=${24}
model_dir=${25}
benchmark_mode=${26}
trtllm_repo=${27}
trtllm_repo=${27:-""}

# Get GPUs per node dynamically from SLURM
ntasks_per_node=${SLURM_NTASKS_PER_NODE:-4} # Default to 4 for GB200

echo "================= parameters ================="
echo "num_ctx_servers: ${num_ctx_servers}"
@@ -72,6 +79,7 @@ echo "workdir: ${workdir}"
echo "model_dir: ${model_dir}"
echo "benchmark_mode: ${benchmark_mode}"
echo "trtllm_repo: ${trtllm_repo}"
echo "ntasks_per_node: ${ntasks_per_node}"
echo "==========================================="


@@ -80,8 +88,8 @@ gen_max_seq_len=$((isl + osl))
ctx_gpu_frac=${ctx_gpu_memory_fraction}
cache_transceiver_max_num_tokens=8448

container_name=disaggr
logdir=${workdir}/benchmark-${isl}-${osl}
container_name=disaggregated_serving
logdir=${workdir}/slurm-${SLURM_JOB_ID}/benchmark-${isl}-${osl}
mkdir -p ${logdir}
full_logdir=${logdir}/ctx${num_ctx_servers}_gen${num_gen_servers}_dep${gen_tp_size}_batch${gen_batch_size}_eplb${eplb_num_slots}_mtp${mtp_size}

@@ -107,13 +115,14 @@ if [ "${benchmark_mode}" != "gen_only" ] && [ "${benchmark_mode}" != "e2e" ]; then
benchmark_mode="e2e"
fi

if [ -z "${TRT_LLM_GIT_COMMIT}" ]; then
if [ -z "${TRT_LLM_GIT_COMMIT:-}" ]; then
export TRT_LLM_GIT_COMMIT=$(git -C ${trtllm_repo} rev-parse --short HEAD 2>/dev/null || echo "unknown")
echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}"
fi

nsys_on=""
# nsys_on=${full_logdir} # Uncomment this line to enable Nsys profiling

# start the container
srun -l --container-image=${container_image} \
--container-name=${container_name} \
@@ -128,60 +137,92 @@ if [ -n "${trtllm_repo}" ]; then
bash -c "cd ${trtllm_repo} && echo 'Running install operation...' && pip install -e . " 2>&1 | tee ${full_logdir}/install.log
fi

# generate the yaml file
srun -l --container-name=${container_name} \
echo "Generating YAML file for workers."
srun -l -N 1 -n 1 \
--container-name=${container_name} \
--container-mounts=${mounts} \
--mpi=pmix --overlap \
python3 ${workdir}/gen_yaml.py --config ${full_logdir}/config.yaml \
--model ${model_dir} \
--num_ctx_servers ${num_ctx_servers} \
--ctx_tp_size ${ctx_tp_size} \
--ctx_pp_size ${ctx_pp_size} \
--ctx_batch_size ${ctx_batch_size} \
--ctx_max_num_tokens ${ctx_max_num_tokens} \
--ctx_max_seq_len ${ctx_max_seq_len} \
--ctx_free_gpu_memory_fraction ${ctx_gpu_frac} \
--cache_transceiver_max_num_tokens ${cache_transceiver_max_num_tokens} \
--num_gen_servers ${num_gen_servers} \
--gen_tp_size ${gen_tp_size} \
--gen_pp_size ${gen_pp_size} \
--gen_batch_size ${gen_batch_size} \
--gen_max_num_tokens ${gen_max_num_tokens} \
--gen_max_seq_len ${gen_max_seq_len} \
--gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \
--eplb_num_slots ${eplb_num_slots} \
$(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \
$(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi) \
$(if [ "${mtp_size}" -gt 0 ]; then echo "--mtp_size ${mtp_size}"; fi)
python3 ${workdir}/gen_worker_config.py \
--work_dir ${full_logdir} \
--ctx_tp_size ${ctx_tp_size} \
--ctx_pp_size ${ctx_pp_size} \
--ctx_batch_size ${ctx_batch_size} \
--ctx_max_num_tokens ${ctx_max_num_tokens} \
--ctx_max_seq_len ${ctx_max_seq_len} \
--ctx_free_gpu_memory_fraction ${ctx_gpu_frac} \
--gen_tp_size ${gen_tp_size} \
--gen_pp_size ${gen_pp_size} \
--gen_batch_size ${gen_batch_size} \
--gen_max_num_tokens ${gen_max_num_tokens} \
--gen_max_seq_len ${gen_max_seq_len} \
--gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \
--eplb_num_slots ${eplb_num_slots} \
--mtp_size ${mtp_size} \
--cache_transceiver_max_num_tokens ${cache_transceiver_max_num_tokens} \
$(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi) \
$(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \
2>&1 | tee ${full_logdir}/gen_worker_config.log

echo "YAML file generated."

hostname_value=$(grep '^hostname:' ${full_logdir}/config.yaml | awk -F': ' '{print $2}')
echo "server host name: $hostname_value"
ctx_nodes_num=$(((ctx_tp_size + ntasks_per_node - 1) / ntasks_per_node))
gen_nodes_num=$(((gen_tp_size + ntasks_per_node - 1) / ntasks_per_node))

all_nodes=($(scontrol show hostname $SLURM_NODELIST | sort))
total_nodes_num=${#all_nodes[@]}
echo "all_nodes: ${all_nodes[@]}, total_nodes_num: ${total_nodes_num}"

# start the workers
srun -l --container-name=${container_name} \
# get the node list for the gen workers
total_gen_nodes_num=$((gen_nodes_num * num_gen_servers))
gen_nodes=(${all_nodes[@]:0:${total_gen_nodes_num}})
echo "gen_nodes: ${gen_nodes[@]}, total_gen_nodes_num: ${total_gen_nodes_num}"

# get the node list for the ctx workers
total_ctx_nodes_num=$((ctx_nodes_num * num_ctx_servers))
ctx_nodes=(${all_nodes[@]:${total_gen_nodes_num}:${total_nodes_num}})
echo "ctx_nodes: ${ctx_nodes[@]}, total_ctx_nodes_num: ${total_ctx_nodes_num}"

rm -rf ${full_logdir}/hostnames

# start the gen workers
for i in $(seq 0 $((num_gen_servers - 1))); do
srun -l -N ${gen_nodes_num} \
--ntasks=${gen_tp_size} \
--ntasks-per-node=${ntasks_per_node} \
--container-image=${container_image} \
--container-name=${container_name} \
--container-mounts=${mounts} \
--mpi=pmix --overlap \
bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${benchmark_mode} ${concurrency} ${nsys_on} &> ${full_logdir}/output_workers.log &
--mpi=pmix \
bash ${workdir}/start_worker.sh "GEN" ${i} ${model_dir} "8336" ${benchmark_mode} ${concurrency} ${enable_pdl} ${full_logdir} ${nsys_on} \
&> ${full_logdir}/output_gen_${i}.log &
done

# start the ctx workers
for i in $(seq 0 $((num_ctx_servers - 1))); do
srun -l -N ${ctx_nodes_num} \
--ntasks=${ctx_tp_size} \
--ntasks-per-node=${ntasks_per_node} \
--container-image=${container_image} \
--container-name=${container_name} \
--container-mounts=${mounts} \
--mpi=pmix \
bash ${workdir}/start_worker.sh "CTX" ${i} ${model_dir} "8336" ${benchmark_mode} ${concurrency} ${enable_pdl} ${full_logdir} ${nsys_on} \
&> ${full_logdir}/output_ctx_${i}.log &
done

# start the server
srun -l --container-name=${container_name} \
--container-mounts=${mounts} \
--mpi=pmix --overlap -N 1 -n 1 \
-w ${hostname_value} \
bash ${workdir}/start_server.sh ${full_logdir}/config.yaml &> ${full_logdir}/output_server.log &
--container-image=${container_image} \
--container-mounts=${mounts} \
--mpi=pmix --overlap -N 1 -n 1 \
bash ${workdir}/start_server.sh ${num_ctx_servers} ${num_gen_servers} ${full_logdir} ${workdir} \
&> ${full_logdir}/output_server.log &

# start benchmarking
srun -l --container-name=${container_name} \
--container-mounts=${mounts} \
--mpi=pmix --overlap -N 1 -n 1 \
bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir} > ${full_logdir}/benchmark.log 2>&1
--container-mounts=${mounts} \
--mpi=pmix --overlap -N 1 -n 1 \
bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir} \
&> ${full_logdir}/benchmark.log 2>&1

# try to kill the server and workers
srun -l --container-name=${container_name} \
--container-mounts=${mounts} \
--mpi=pmix --overlap \
kill -9 $(ps aux | grep '[t]rtllm-serve' | awk '{print $2}') >/dev/null 2>&1 || true
wait
scancel ${SLURM_JOB_ID}
90 changes: 90 additions & 0 deletions examples/disaggregated/slurm/benchmark/gen_server_config.py
@@ -0,0 +1,90 @@
import argparse
import os
import socket
import time

import yaml

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num_ctx_servers",
type=int,
required=True,
help="Number of context servers")
parser.add_argument("--num_gen_servers",
type=int,
required=True,
help="Number of generation servers")
parser.add_argument("--work_dir",
type=str,
default="logs",
help="Work directory")
parser.add_argument("--worker_port",
type=int,
default=8336,
help="Worker port")
parser.add_argument("--server_port",
type=int,
default=8333,
help="Server port")
args = parser.parse_args()

# check if the work_dir exists
if not os.path.exists(args.work_dir):
raise ValueError(f"Work directory {args.work_dir} not found")

    # Check that all hostname files exist in the hostnames folder; if not, sleep 10 seconds and check again.
hostnames_folder = os.path.join(args.work_dir, "hostnames")
while not os.path.exists(hostnames_folder):
time.sleep(10)
print(f"Waiting for hostnames folder {hostnames_folder} to be found")
hostnames = os.listdir(hostnames_folder)
# check length of hostnames is equal to num_ctx_servers + num_gen_servers, if not, sleep 10 seconds and check again
while len(hostnames) != args.num_ctx_servers + args.num_gen_servers:
time.sleep(10)
hostnames = os.listdir(hostnames_folder)
print(
f"Waiting for hostnames to be found in {hostnames_folder}, current length: {len(hostnames)}, expected length: {args.num_ctx_servers + args.num_gen_servers}"
)
print(f"All hostnames found in {hostnames_folder}")

# get the ctx and gen hostnames from the hostnames file
ctx_hostnames = []
gen_hostnames = []
for hostname_file in hostnames:
hostname_file_path = os.path.join(hostnames_folder, hostname_file)
with open(hostname_file_path, 'r') as f:
actual_hostname = f.read().strip()
print(f"Hostname: {actual_hostname} in {hostname_file}")

if hostname_file.startswith("CTX"):
ctx_hostnames.append(actual_hostname)
elif hostname_file.startswith("GEN"):
gen_hostnames.append(actual_hostname)

print(f"ctx_hostnames: {ctx_hostnames}")
print(f"gen_hostnames: {gen_hostnames}")

# get current hostname from env
hostname = socket.gethostname()
print(f"Current hostname: {hostname}")

server_config = {
'hostname': hostname,
'port': args.server_port,
'backend': 'pytorch',
'context_servers': {
'num_instances': args.num_ctx_servers,
'urls': [f'{host}:{args.worker_port}' for host in ctx_hostnames]
},
'generation_servers': {
'num_instances': args.num_gen_servers,
'urls': [f'{host}:{args.worker_port}' for host in gen_hostnames]
}
}

with open(os.path.join(args.work_dir, "server_config.yaml"), "w") as f:
yaml.dump(server_config, f)
print(
f"Server config file {os.path.join(args.work_dir, 'server_config.yaml')} generated"
)