From 0eac701dcad06f4a27d309c8d771a456aa0ec92d Mon Sep 17 00:00:00 2001 From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Date: Thu, 21 Aug 2025 02:49:01 -0700 Subject: [PATCH 1/2] Make disagg example compatible with recommended usage Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Update Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Update Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Fix Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Update Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Update Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Update documents Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Better error handling Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Fix Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Fix Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Fix Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Update Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Update Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Minor fix Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> --- .../disaggregated/slurm/benchmark/README.md | 58 ++- .../slurm/benchmark/disaggr_torch.slurm | 133 ++++-- .../slurm/benchmark/gen_server_config.py | 90 ++++ .../slurm/benchmark/gen_worker_config.py | 233 +++++++++++ .../disaggregated/slurm/benchmark/gen_yaml.py | 395 ------------------ .../slurm/benchmark/run_benchmark.sh | 33 +- .../slurm/benchmark/start_server.sh | 45 +- .../slurm/benchmark/start_worker.sh | 59 ++- 8 files changed, 530 insertions(+), 516 deletions(-) create mode 100644 examples/disaggregated/slurm/benchmark/gen_server_config.py create mode 100644 examples/disaggregated/slurm/benchmark/gen_worker_config.py delete mode 100644 examples/disaggregated/slurm/benchmark/gen_yaml.py diff --git a/examples/disaggregated/slurm/benchmark/README.md b/examples/disaggregated/slurm/benchmark/README.md index 7875d693ce1..aeee9a5dcd7 100644 --- a/examples/disaggregated/slurm/benchmark/README.md +++ b/examples/disaggregated/slurm/benchmark/README.md @@ -4,13 +4,15 @@ This directory contains scripts to run disaggregated inference benchmarks using ## Overview -The benchmarking process is orchestrated through a set of shell scripts and a Python script that work together: +The benchmarking process is orchestrated through a set of shell scripts and Python scripts that work together: 1. `submit.sh`: The main entry point for submitting benchmark jobs to SLURM. It runs a parameter sweep by calling `sbatch` with different configurations. -2. `disaggr_torch.slurm`: The SLURM script that sets up and runs a single benchmark experiment. It launches a container, generates a configuration file, starts the server and workers, and runs the benchmark client. -3. `gen_yaml.py`: A Python script that generates the `config.yaml` file needed by `trtllm-serve`. It determines the server and worker configuration based on SLURM environment variables and script arguments. -4. `start_worker.sh`: A shell script responsible for starting a `trtllm-serve disaggregated_mpi_worker` on each allocated machine. -5. `run_benchmark.sh`: A shell script that waits for the server to be healthy and then runs the actual benchmark client (`run_benchmark.py`, not included in this directory). +2. 
`disaggr_torch.slurm`: The SLURM script that sets up and runs a single benchmark experiment. It launches a container, generates configuration files, starts the server and workers, and runs the benchmark client. +3. `gen_worker_config.py`: A Python script that generates the worker configuration YAML file needed by `trtllm-serve`. It determines the worker configuration based on SLURM environment variables and script arguments. +4. `gen_server_config.py`: A Python script that generates the server configuration YAML file needed by `trtllm-serve`. It determines the server configuration based on the number of context and generation servers. +5. `start_worker.sh`: A shell script responsible for starting disaggregated workers using `trtllm-serve` on each allocated machine. +6. `start_server.sh`: A shell script responsible for starting disaggregated server using `trtllm-serve` on each allocated machine. +7. `run_benchmark.sh`: A shell script that waits for the server to be healthy and then runs the actual benchmark client (`run_benchmark.py`, not included in this directory). ## File Descriptions @@ -58,13 +60,21 @@ It takes the following arguments in order: 24. `model_dir`: Model directory path. 25. `trtllm_repo`: TensorRT-LLM repository path. -### `gen_yaml.py` +### `gen_worker_config.py` -This Python script generates the `config.yaml` file that configures the `trtllm-serve` application. It reads SLURM environment variables (`SLURM_JOB_NODELIST`, `SLURM_TASKS_PER_NODE`) to distribute workers across nodes. +This Python script generates the worker configuration YAML file that configures the `trtllm-serve` workers. It creates separate configurations for context and generation workers with different tensor parallelism, batch sizes, and other parameters. **Usage:** -The script is called from within `disaggr_torch.slurm`. It takes numerous arguments to define the model, parallelism, and server configurations. +The script is called from within `disaggr_torch.slurm`. It takes numerous arguments to define the model, parallelism, and worker configurations for both context and generation phases. + +### `gen_server_config.py` + +This Python script generates the server configuration YAML file that configures the `trtllm-serve` disaggregated server. It reads hostname information from the work directory and creates a configuration that specifies the URLs for context and generation servers. + +**Usage:** + +The script is called from within `start_server.sh`. It takes arguments for the number of context and generation servers and the work directory. ### `start_worker.sh` @@ -72,14 +82,30 @@ This script starts a `trtllm-serve disaggregated_mpi_worker`. It is launched by **Arguments:** -1. `config_file`: Path to the `config.yaml` file. -2. `enable_pdl`: `true` or `false`. -3. `ctx_gpus`: Number of GPUs used for the context phase. -4. `work_dir`: (Optional) Directory to store nsys profiling output. +1. `worker_type`: Either "CTX" or "GEN" to specify the worker type. +2. `worker_index`: Index of the worker instance. +3. `model_dir`: Path to the model directory. +4. `worker_port`: Port for the worker to listen on. +5. `benchmark_mode`: Benchmark mode setting. +6. `concurrency`: Concurrency level. +7. `enable_pdl`: `true` or `false`. +8. `work_dir`: Work directory for logs and configuration. +9. `nsys_on`: Whether to enable nsys profiling. + +### `start_server.sh` + +This script starts the `trtllm-serve disaggregated` server. 
It first generates the server configuration using `gen_server_config.py`, then starts the server process. + +**Arguments:** + +1. `num_ctx_servers`: Number of context servers. +2. `num_gen_servers`: Number of generation servers. +3. `work_dir`: Work directory for logs and configuration. +4. `script_dir`: Directory containing the scripts. ### `run_benchmark.sh` -This script orchestrates the execution of the benchmark client. It waits for the `config.yaml` to be created and for the server's `/health` endpoint to respond, then it runs the benchmark. +This script orchestrates the execution of the benchmark client. It waits for the configuration files to be created and for the server's `/health` endpoint to respond, then it runs the benchmark. **Arguments:** @@ -97,9 +123,9 @@ This script orchestrates the execution of the benchmark client. It waits for the 2. The user runs `./submit.sh`. 3. `submit.sh` submits one or more jobs to SLURM by calling `sbatch disaggr_torch.slurm` with different parameters. 4. For each job, SLURM allocates resources and runs `disaggr_torch.slurm`. -5. `disaggr_torch.slurm` runs `gen_yaml.py` to create a `config.yaml`. -6. `disaggr_torch.slurm` uses `srun` to launch `start_worker.sh` on all nodes, starting the MPI workers. -7. `disaggr_torch.slurm` starts the main `trtllm-serve` process. +5. `disaggr_torch.slurm` runs `gen_worker_config.py` to create worker configuration files. +6. `disaggr_torch.slurm` uses `srun` to launch `start_worker.sh` on all nodes, starting the MPI workers for both context and generation phases. +7. `disaggr_torch.slurm` starts the main `trtllm-serve` process using `start_server.sh`, which generates the server configuration using `gen_server_config.py`. 8. `disaggr_torch.slurm` runs `run_benchmark.sh` which waits for the server to be ready. 9. `run_benchmark.sh` executes the benchmark for each concurrency level specified. 10. After the benchmark, `run_benchmark.sh` and `disaggr_torch.slurm` attempt to kill the server and worker processes. 
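+
+As a reference for step 7, below is a sketch of the `server_config.yaml` that `gen_server_config.py` writes into the work directory. The node names and the single server of each kind are purely illustrative; the URLs come from the hostname files written by `start_worker.sh`, and the ports shown are the script defaults.
+
+```yaml
+hostname: node-001          # node running start_server.sh
+port: 8333                  # --server_port default
+backend: pytorch
+context_servers:
+  num_instances: 1
+  urls:
+  - node-002:8336           # --worker_port default
+generation_servers:
+  num_instances: 1
+  urls:
+  - node-003:8336
+```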
diff --git a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm index 377544ab23d..8d1f9257584 100644 --- a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm +++ b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm @@ -7,6 +7,10 @@ #SBATCH --job-name=${job_name} # add your job name here or specify in the sbatch command #SBATCH --time=02:00:00 +set -u +set -e +set -x + # Context servers arguments num_ctx_servers=${1} ctx_tp_size=${2} @@ -42,7 +46,10 @@ mounts=${23} workdir=${24} model_dir=${25} benchmark_mode=${26} -trtllm_repo=${27} +trtllm_repo=${27:-""} + +# Get GPUs per node dynamically from SLURM +ntasks_per_node=${SLURM_NTASKS_PER_NODE:-4} # Default to 4 for GB200 echo "================= parameters =================" echo "num_ctx_servers: ${num_ctx_servers}" @@ -72,6 +79,7 @@ echo "workdir: ${workdir}" echo "model_dir: ${model_dir}" echo "benchmark_mode: ${benchmark_mode}" echo "trtllm_repo: ${trtllm_repo}" +echo "ntasks_per_node: ${ntasks_per_node}" echo "===========================================" @@ -80,8 +88,8 @@ gen_max_seq_len=$((isl + osl)) ctx_gpu_frac=${ctx_gpu_memory_fraction} cache_transceiver_max_num_tokens=8448 -container_name=disaggr -logdir=${workdir}/benchmark-${isl}-${osl} +container_name=disaggregated_serving +logdir=${workdir}/slurm-${SLURM_JOB_ID}/benchmark-${isl}-${osl} mkdir -p ${logdir} full_logdir=${logdir}/ctx${num_ctx_servers}_gen${num_gen_servers}_dep${gen_tp_size}_batch${gen_batch_size}_eplb${eplb_num_slots}_mtp${mtp_size} @@ -107,13 +115,14 @@ if [ "${benchmark_mode}" != "gen_only" ] && [ "${benchmark_mode}" != "e2e" ]; th benchmark_mode="e2e" fi -if [ -z "${TRT_LLM_GIT_COMMIT}" ]; then +if [ -z "${TRT_LLM_GIT_COMMIT:-}" ]; then export TRT_LLM_GIT_COMMIT=$(git -C ${trtllm_repo} rev-parse --short HEAD 2>/dev/null || echo "unknown") echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}" fi nsys_on="" # nsys_on=${full_logdir} # Uncomment this line to enable Nsys profiling + # start the container srun -l --container-image=${container_image} \ --container-name=${container_name} \ @@ -128,60 +137,92 @@ if [ -n "${trtllm_repo}" ]; then bash -c "cd ${trtllm_repo} && echo 'Running install operation...' && pip install -e . " 2>&1 | tee ${full_logdir}/install.log fi -# generate the yaml file -srun -l --container-name=${container_name} \ +echo "Generating YAML file for workers." 
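+# gen_worker_config.py only writes ctx_config.yaml and gen_config.yaml into
+# ${full_logdir}; it needs no GPUs, so a single task on one node is enough.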
+srun -l -N 1 -n 1 \ + --container-name=${container_name} \ --container-mounts=${mounts} \ --mpi=pmix --overlap \ - python3 ${workdir}/gen_yaml.py --config ${full_logdir}/config.yaml \ - --model ${model_dir} \ - --num_ctx_servers ${num_ctx_servers} \ - --ctx_tp_size ${ctx_tp_size} \ - --ctx_pp_size ${ctx_pp_size} \ - --ctx_batch_size ${ctx_batch_size} \ - --ctx_max_num_tokens ${ctx_max_num_tokens} \ - --ctx_max_seq_len ${ctx_max_seq_len} \ - --ctx_free_gpu_memory_fraction ${ctx_gpu_frac} \ - --cache_transceiver_max_num_tokens ${cache_transceiver_max_num_tokens} \ - --num_gen_servers ${num_gen_servers} \ - --gen_tp_size ${gen_tp_size} \ - --gen_pp_size ${gen_pp_size} \ - --gen_batch_size ${gen_batch_size} \ - --gen_max_num_tokens ${gen_max_num_tokens} \ - --gen_max_seq_len ${gen_max_seq_len} \ - --gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \ - --eplb_num_slots ${eplb_num_slots} \ - $(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \ - $(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi) \ - $(if [ "${mtp_size}" -gt 0 ]; then echo "--mtp_size ${mtp_size}"; fi) + python3 ${workdir}/gen_worker_config.py \ + --work_dir ${full_logdir} \ + --ctx_tp_size ${ctx_tp_size} \ + --ctx_pp_size ${ctx_pp_size} \ + --ctx_batch_size ${ctx_batch_size} \ + --ctx_max_num_tokens ${ctx_max_num_tokens} \ + --ctx_max_seq_len ${ctx_max_seq_len} \ + --ctx_free_gpu_memory_fraction ${ctx_gpu_frac} \ + --gen_tp_size ${gen_tp_size} \ + --gen_pp_size ${gen_pp_size} \ + --gen_batch_size ${gen_batch_size} \ + --gen_max_num_tokens ${gen_max_num_tokens} \ + --gen_max_seq_len ${gen_max_seq_len} \ + --gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \ + --eplb_num_slots ${eplb_num_slots} \ + --mtp_size ${mtp_size} \ + --cache_transceiver_max_num_tokens ${cache_transceiver_max_num_tokens} \ + $(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi) \ + $(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \ + 2>&1 | tee ${full_logdir}/gen_worker_config.log echo "YAML file generated." 
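+# Node partitioning for the workers started below: each generation server spans
+# ceil(gen_tp_size / ntasks_per_node) nodes and each context server spans
+# ceil(ctx_tp_size / ntasks_per_node); generation servers take the first nodes
+# of the allocation, context servers the remaining ones.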
-hostname_value=$(grep '^hostname:' ${full_logdir}/config.yaml | awk -F': ' '{print $2}') -echo "server host name: $hostname_value" +ctx_nodes_num=$(((ctx_tp_size + ntasks_per_node - 1) / ntasks_per_node)) +gen_nodes_num=$(((gen_tp_size + ntasks_per_node - 1) / ntasks_per_node)) +all_nodes=($(scontrol show hostname $SLURM_NODELIST | sort)) +total_nodes_num=${#all_nodes[@]} +echo "all_nodes: ${all_nodes[@]}, total_nodes_num: ${total_nodes_num}" -# start the workers -srun -l --container-name=${container_name} \ +# get the node list for the gen workers +total_gen_nodes_num=$((gen_nodes_num * num_gen_servers)) +gen_nodes=(${all_nodes[@]:0:${total_gen_nodes_num}}) +echo "gen_nodes: ${gen_nodes[@]}, total_gen_nodes_num: ${total_gen_nodes_num}" + +# get the node list for the ctx workers +total_ctx_nodes_num=$((ctx_nodes_num * num_ctx_servers)) +ctx_nodes=(${all_nodes[@]:${total_gen_nodes_num}:${total_nodes_num}}) +echo "ctx_nodes: ${ctx_nodes[@]}, total_ctx_nodes_num: ${total_ctx_nodes_num}" + +rm -rf ${full_logdir}/hostnames + +# start the gen workers +for i in $(seq 0 $((num_gen_servers - 1))); do + srun -l -N ${gen_nodes_num} \ + --ntasks=${gen_tp_size} \ + --ntasks-per-node=${ntasks_per_node} \ + --container-image=${container_image} \ + --container-name=${container_name} \ --container-mounts=${mounts} \ - --mpi=pmix --overlap \ - bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${benchmark_mode} ${concurrency} ${nsys_on} &> ${full_logdir}/output_workers.log & + --mpi=pmix \ + bash ${workdir}/start_worker.sh "GEN" ${i} ${model_dir} "8336" ${benchmark_mode} ${concurrency} ${enable_pdl} ${full_logdir} ${nsys_on} \ + &> ${full_logdir}/output_gen_${i}.log & +done + +# start the ctx workers +for i in $(seq 0 $((num_ctx_servers - 1))); do + srun -l -N ${ctx_nodes_num} \ + --ntasks=${ctx_tp_size} \ + --ntasks-per-node=${ntasks_per_node} \ + --container-image=${container_image} \ + --container-name=${container_name} \ + --container-mounts=${mounts} \ + --mpi=pmix \ + bash ${workdir}/start_worker.sh "CTX" ${i} ${model_dir} "8336" ${benchmark_mode} ${concurrency} ${enable_pdl} ${full_logdir} ${nsys_on} \ + &> ${full_logdir}/output_ctx_${i}.log & +done # start the server srun -l --container-name=${container_name} \ - --container-mounts=${mounts} \ - --mpi=pmix --overlap -N 1 -n 1 \ - -w ${hostname_value} \ - bash ${workdir}/start_server.sh ${full_logdir}/config.yaml &> ${full_logdir}/output_server.log & + --container-image=${container_image} \ + --container-mounts=${mounts} \ + --mpi=pmix --overlap -N 1 -n 1 \ + bash ${workdir}/start_server.sh ${num_ctx_servers} ${num_gen_servers} ${full_logdir} ${workdir} \ + &> ${full_logdir}/output_server.log & # start benchmarking srun -l --container-name=${container_name} \ - --container-mounts=${mounts} \ - --mpi=pmix --overlap -N 1 -n 1 \ - bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir} > ${full_logdir}/benchmark.log 2>&1 + --container-mounts=${mounts} \ + --mpi=pmix --overlap -N 1 -n 1 \ + bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir} \ + &> ${full_logdir}/benchmark.log 2>&1 -# try to kill the server and workers -srun -l --container-name=${container_name} \ - --container-mounts=${mounts} \ - --mpi=pmix --overlap \ - kill -9 $(ps aux | grep '[t]rtllm-serve' | awk '{print $2}') >/dev/null 2>&1 || true -wait +scancel ${SLURM_JOB_ID} diff --git 
a/examples/disaggregated/slurm/benchmark/gen_server_config.py b/examples/disaggregated/slurm/benchmark/gen_server_config.py new file mode 100644 index 00000000000..c427f5d42b4 --- /dev/null +++ b/examples/disaggregated/slurm/benchmark/gen_server_config.py @@ -0,0 +1,90 @@ +import argparse +import os +import socket +import time + +import yaml + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num_ctx_servers", + type=int, + required=True, + help="Number of context servers") + parser.add_argument("--num_gen_servers", + type=int, + required=True, + help="Number of generation servers") + parser.add_argument("--work_dir", + type=str, + default="logs", + help="Work directory") + parser.add_argument("--worker_port", + type=int, + default=8336, + help="Worker port") + parser.add_argument("--server_port", + type=int, + default=8333, + help="Server port") + args = parser.parse_args() + + # check if the work_dir exists + if not os.path.exists(args.work_dir): + raise ValueError(f"Work directory {args.work_dir} not found") + + #check all of the hostnames in the hostnames folder exists, if not, sleep 10 seconds and check again + hostnames_folder = os.path.join(args.work_dir, "hostnames") + while not os.path.exists(hostnames_folder): + time.sleep(10) + print(f"Waiting for hostnames folder {hostnames_folder} to be found") + hostnames = os.listdir(hostnames_folder) + # check length of hostnames is equal to num_ctx_servers + num_gen_servers, if not, sleep 10 seconds and check again + while len(hostnames) != args.num_ctx_servers + args.num_gen_servers: + time.sleep(10) + hostnames = os.listdir(hostnames_folder) + print( + f"Waiting for hostnames to be found in {hostnames_folder}, current length: {len(hostnames)}, expected length: {args.num_ctx_servers + args.num_gen_servers}" + ) + print(f"All hostnames found in {hostnames_folder}") + + # get the ctx and gen hostnames from the hostnames file + ctx_hostnames = [] + gen_hostnames = [] + for hostname_file in hostnames: + hostname_file_path = os.path.join(hostnames_folder, hostname_file) + with open(hostname_file_path, 'r') as f: + actual_hostname = f.read().strip() + print(f"Hostname: {actual_hostname} in {hostname_file}") + + if hostname_file.startswith("CTX"): + ctx_hostnames.append(actual_hostname) + elif hostname_file.startswith("GEN"): + gen_hostnames.append(actual_hostname) + + print(f"ctx_hostnames: {ctx_hostnames}") + print(f"gen_hostnames: {gen_hostnames}") + + # get current hostname from env + hostname = socket.gethostname() + print(f"Current hostname: {hostname}") + + server_config = { + 'hostname': hostname, + 'port': args.server_port, + 'backend': 'pytorch', + 'context_servers': { + 'num_instances': args.num_ctx_servers, + 'urls': [f'{host}:{args.worker_port}' for host in ctx_hostnames] + }, + 'generation_servers': { + 'num_instances': args.num_gen_servers, + 'urls': [f'{host}:{args.worker_port}' for host in gen_hostnames] + } + } + + with open(os.path.join(args.work_dir, "server_config.yaml"), "w") as f: + yaml.dump(server_config, f) + print( + f"Server config file {os.path.join(args.work_dir, 'server_config.yaml')} generated" + ) diff --git a/examples/disaggregated/slurm/benchmark/gen_worker_config.py b/examples/disaggregated/slurm/benchmark/gen_worker_config.py new file mode 100644 index 00000000000..adee5cbe72c --- /dev/null +++ b/examples/disaggregated/slurm/benchmark/gen_worker_config.py @@ -0,0 +1,233 @@ +import argparse +import os + +import yaml + + +def gen_config_file(work_dir: str, + ctx_tp_size: 
int, + ctx_pp_size: int, + ctx_batch_size: int, + ctx_max_num_tokens: int, + ctx_max_seq_len: int, + ctx_free_gpu_memory_fraction: float, + ctx_enable_attention_dp: bool, + gen_tp_size: int, + gen_pp_size: int, + gen_batch_size: int, + gen_max_num_tokens: int, + gen_max_seq_len: int, + gen_enable_attention_dp: bool, + gen_gpu_memory_fraction: float, + eplb_num_slots: int, + mtp_size: int = 0, + cache_transceiver_max_num_tokens: int = 4608) -> None: + """ + Generate configuration YAML file for disaggregated inference. + + Args: + config_path: Path to save the config file + model_path: Path to the model + num_ctx_servers: Number of context servers + ctx_tp_size: Tensor parallel size for context servers + ctx_pp_size: Pipeline parallel size for context servers + ctx_batch_size: Batch size for context servers + ctx_max_num_tokens: Max number of tokens for context servers + ctx_max_seq_len: Max sequence length for context servers + ctx_free_gpu_memory_fraction: Free GPU memory fraction for context servers + ctx_enable_attention_dp: Enable attention DP for context servers + num_gen_servers: Number of generation servers + gen_tp_size: Tensor parallel size for generation servers + gen_pp_size: Pipeline parallel size for generation servers + gen_batch_size: Batch size for generation servers + gen_max_num_tokens: Max number of tokens for generation servers + gen_enable_attention_dp: Enable attention DP for generation servers + gen_gpu_memory_fraction: GPU memory fraction for generation servers + eplb_num_slots: Number of slots for eplb + worker_start_port: Start port for workers + server_port: Server port + """ + ctx_config = { + 'max_batch_size': ctx_batch_size, + 'max_num_tokens': ctx_max_num_tokens, + 'max_seq_len': ctx_max_seq_len, + 'tensor_parallel_size': ctx_tp_size, + 'moe_expert_parallel_size': ctx_tp_size, + 'enable_attention_dp': True if ctx_enable_attention_dp else False, + 'pipeline_parallel_size': ctx_pp_size, + 'print_iter_log': True, + 'disable_overlap_scheduler': True, + 'kv_cache_config': { + 'enable_block_reuse': False, + 'free_gpu_memory_fraction': ctx_free_gpu_memory_fraction, + 'dtype': 'fp8', + }, + 'cache_transceiver_config': { + 'max_tokens_in_buffer': cache_transceiver_max_num_tokens, + 'backend': 'DEFAULT', + }, + } + + gen_cuda_graph_batch_sizes = [ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024, 2048, gen_batch_size + ] + + gen_moe_backend = "CUTLASS" + if gen_tp_size >= 16 and gen_enable_attention_dp: + gen_moe_backend = "WIDEEP" + if not gen_enable_attention_dp: + gen_moe_backend = "TRTLLM" + + gen_config = { + 'tensor_parallel_size': gen_tp_size, + 'moe_expert_parallel_size': gen_tp_size, + 'enable_attention_dp': True if gen_enable_attention_dp else False, + 'pipeline_parallel_size': gen_pp_size, + 'max_batch_size': gen_batch_size, + 'max_num_tokens': gen_max_num_tokens, + 'max_seq_len': gen_max_seq_len, + 'cuda_graph_config': { + 'enable_padding': True, + 'batch_sizes': gen_cuda_graph_batch_sizes, + }, + 'print_iter_log': True, + 'kv_cache_config': { + 'enable_block_reuse': False, + 'free_gpu_memory_fraction': gen_gpu_memory_fraction, + 'dtype': 'fp8', + }, + 'moe_config': { + 'backend': gen_moe_backend, + }, + 'cache_transceiver_config': { + 'max_tokens_in_buffer': cache_transceiver_max_num_tokens, + 'backend': 'DEFAULT', + }, + 'stream_interval': 20, + } + + if gen_tp_size == 8 and not gen_enable_attention_dp: + gen_config['allreduce_strategy'] = "MNNVL" + + if eplb_num_slots > 0: + moe_load_balancer_file = os.path.join(work_dir, + "moe_load_balancer.yaml") + 
moe_load_balancer_config = { + 'num_slots': eplb_num_slots, + 'layer_updates_per_iter': 1 + } + with open(moe_load_balancer_file, "w") as f: + yaml.dump(moe_load_balancer_config, + f, + default_flow_style=False, + sort_keys=False) + gen_config['moe_config']['load_balancer'] = moe_load_balancer_file + + if mtp_size > 0: + ctx_config['speculative_config'] = { + 'decoding_type': 'MTP', + 'num_nextn_predict_layers': mtp_size + } + gen_config['speculative_config'] = { + 'decoding_type': 'MTP', + 'num_nextn_predict_layers': mtp_size + } + + ctx_config_file = os.path.join(work_dir, "ctx_config.yaml") + gen_config_file = os.path.join(work_dir, "gen_config.yaml") + with open(ctx_config_file, "w") as f: + yaml.dump(ctx_config, f, default_flow_style=False, sort_keys=False) + with open(gen_config_file, "w") as f: + yaml.dump(gen_config, f, default_flow_style=False, sort_keys=False) + + print( + f"ctx_config_file: {ctx_config_file} gen_config_file: {gen_config_file} generated successfully" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--work_dir", + type=str, + default="logs", + help="Work directory") + parser.add_argument("--ctx_tp_size", + type=int, + default=4, + help="Tensor parallel size for context servers") + parser.add_argument("--ctx_pp_size", + type=int, + default=1, + help="Pipeline parallel size for context servers") + parser.add_argument("--ctx_batch_size", + type=int, + default=1, + help="Batch size for context servers") + parser.add_argument("--ctx_max_num_tokens", + type=int, + default=8192, + help="Max number of tokens for context servers") + parser.add_argument("--ctx_max_seq_len", + type=int, + default=8192, + help="Max sequence length for context servers") + parser.add_argument("--ctx_free_gpu_memory_fraction", + type=float, + default=0.75, + help="Free GPU memory fraction for context servers") + parser.add_argument("--ctx_enable_attention_dp", + dest='ctx_enable_attention_dp', + action='store_true', + help="Enable attention DP for context servers") + parser.add_argument("--gen_tp_size", + type=int, + default=8, + help="Tensor parallel size for generation servers") + parser.add_argument("--gen_pp_size", + type=int, + default=1, + help="Pipeline parallel size for generation servers") + parser.add_argument("--gen_batch_size", + type=int, + default=256, + help="Batch size for generation servers") + parser.add_argument("--gen_max_num_tokens", + type=int, + default=256, + help="Max number of tokens for generation servers") + parser.add_argument("--gen_max_seq_len", + type=int, + default=9216, + help="Max sequence length for generation servers") + parser.add_argument("--gen_enable_attention_dp", + dest='gen_enable_attention_dp', + action='store_true', + help="Enable attention DP for generation servers") + parser.add_argument("--gen_gpu_memory_fraction", + type=float, + default=0.8, + help="GPU memory fraction for generation servers") + parser.add_argument("--eplb_num_slots", + type=int, + default=0, + help="Number of slots for eplb") + parser.add_argument("--mtp_size", + type=int, + default=0, + help="Number of nextn layers for MTP") + parser.add_argument("--cache_transceiver_max_num_tokens", + type=int, + default=8448, + help="Max number of tokens for cache transceiver") + + args = parser.parse_args() + + gen_config_file(args.work_dir, args.ctx_tp_size, args.ctx_pp_size, args.ctx_batch_size, + args.ctx_max_num_tokens, args.ctx_max_seq_len, + args.ctx_free_gpu_memory_fraction, + args.ctx_enable_attention_dp, args.gen_tp_size, 
args.gen_pp_size, + args.gen_batch_size, args.gen_max_num_tokens, + args.gen_max_seq_len, args.gen_enable_attention_dp, + args.gen_gpu_memory_fraction, args.eplb_num_slots, + args.mtp_size, args.cache_transceiver_max_num_tokens) diff --git a/examples/disaggregated/slurm/benchmark/gen_yaml.py b/examples/disaggregated/slurm/benchmark/gen_yaml.py deleted file mode 100644 index e0ea7dd4369..00000000000 --- a/examples/disaggregated/slurm/benchmark/gen_yaml.py +++ /dev/null @@ -1,395 +0,0 @@ -import argparse -import os -import re -from typing import Dict, List - -import yaml - - -def process_node_and_task() -> tuple[int, List[str], List[str]]: - """ - Process SLURM node and task environment variables. - - Returns: - tuple: (max_tasks_per_node, nodes, task_nodes) - """ - slurm_job_nodelist = os.getenv('SLURM_JOB_NODELIST', '') - print(f"SLURM_JOB_NODELIST: {slurm_job_nodelist}") - if not slurm_job_nodelist: - raise ValueError(f"Environment variable SLURM_JOB_NODELIST not found.") - - slurm_tasks_per_node = os.getenv('SLURM_TASKS_PER_NODE', '') - print(f"SLURM_TASKS_PER_NODE: {slurm_tasks_per_node}") - if not slurm_tasks_per_node: - raise ValueError( - f"Environment variable SLURM_TASKS_PER_NODE not found.") - - # Generate list of nodes - if '[' in slurm_job_nodelist: - # Handle nodelist with range format - node_prefix = slurm_job_nodelist.split('[')[ - 0] # Extract everything before '[' - node_range = re.search(r'\[(.*?)\]', slurm_job_nodelist).group(1) - nodes = [] - for part in node_range.split(','): - if '-' in part: - start, end = part.split('-') - # Get the width of the number format from the first number - width = len(start) - # Convert to integers after getting the width - start, end = int(start), int(end) - # Format numbers with leading zeros - nodes.extend([ - f"{node_prefix}{str(i).zfill(width)}" - for i in range(start, end + 1) - ]) - else: - # Preserve the original format for single numbers - nodes.append(f"{node_prefix}{part}") - else: - # Handle single node format - nodes = [slurm_job_nodelist] - print(f"Nodes: {nodes}") - - # Generate tasks per node - tasks_per_node = [] - for part in slurm_tasks_per_node.split(','): - if '(x' in part: - count, repeat = map(int, re.findall(r'\d+', part)) - tasks_per_node.extend([count] * repeat) - else: - tasks_per_node.append(int(part)) - print(f"Tasks per node: {tasks_per_node}") - - if (len(tasks_per_node) != len(nodes)): - raise ValueError( - f"Number of nodes and tasks per node do not match. Number of nodes: {len(nodes)}, Number of tasks per node: {len(tasks_per_node)}" - ) - - max_tasks_per_node = max(tasks_per_node) - task_nodes = [] - for node, tasks in zip(nodes, tasks_per_node): - task_nodes.extend([node] * tasks) - - return max_tasks_per_node, nodes, task_nodes - - -def generate_urls(ctx_or_gen: str, - num_instances: int, - tensor_parallel_size: int, - pipeline_parallel_size: int, - max_tasks_per_node: int, - nodes: List[str], - task_nodes: List[str], - node_to_port: Dict[str, int], - task_nodes_offset: int = 0) -> tuple[List[str], int]: - """ - Generate URLs for context or generation servers. - - Returns: - tuple: (urls, updated_task_nodes_offset) - """ - urls = [] - - for instance in range(num_instances): - tasks_needed = tensor_parallel_size * pipeline_parallel_size - - if (task_nodes_offset + tasks_needed) > len(task_nodes): - print(f"{ctx_or_gen} urls so far: {urls}") - raise ValueError( - f"For {ctx_or_gen} instance {instance}, there are not enough tasks available. 
task_nodes_offset: {task_nodes_offset}, tasks_needed: {tasks_needed}, len(task_nodes): {len(task_nodes)}" - ) - - min_node = (tasks_needed + max_tasks_per_node - 1) // max_tasks_per_node - instance_nodes = set(task_nodes[task_nodes_offset:task_nodes_offset + - tasks_needed]) - if len(instance_nodes) > min_node: - raise ValueError( - f"Tasks for a instance {instance} of {ctx_or_gen} instances use more node than expected. Nodes used: {instance_nodes}, number of nodes expected: {min_node}, max_tasks_per_node: {max_tasks_per_node}" - ) - - node = task_nodes[task_nodes_offset] - port = node_to_port[node] - node_to_port[node] += 1 - task_nodes_offset += tasks_needed - - urls.append(f"{node}:{port}") - - print(f"{ctx_or_gen} urls: {urls}") - return urls, task_nodes_offset - - -def gen_config_file(config_path: str, - model_path: str, - num_ctx_servers: int, - ctx_tp_size: int, - ctx_pp_size: int, - ctx_batch_size: int, - ctx_max_num_tokens: int, - ctx_max_seq_len: int, - ctx_free_gpu_memory_fraction: float, - ctx_enable_attention_dp: bool, - num_gen_servers: int, - gen_tp_size: int, - gen_pp_size: int, - gen_batch_size: int, - gen_max_num_tokens: int, - gen_max_seq_len: int, - gen_enable_attention_dp: bool, - gen_gpu_memory_fraction: float, - eplb_num_slots: int, - mtp_size: int = 0, - worker_start_port: int = 8001, - server_port: int = 8000, - cache_transceiver_max_num_tokens: int = 4608) -> None: - """ - Generate configuration YAML file for disaggregated inference. - - Args: - config_path: Path to save the config file - model_path: Path to the model - num_ctx_servers: Number of context servers - ctx_tp_size: Tensor parallel size for context servers - ctx_pp_size: Pipeline parallel size for context servers - ctx_batch_size: Batch size for context servers - ctx_max_num_tokens: Max number of tokens for context servers - ctx_max_seq_len: Max sequence length for context servers - ctx_free_gpu_memory_fraction: Free GPU memory fraction for context servers - ctx_enable_attention_dp: Enable attention DP for context servers - num_gen_servers: Number of generation servers - gen_tp_size: Tensor parallel size for generation servers - gen_pp_size: Pipeline parallel size for generation servers - gen_batch_size: Batch size for generation servers - gen_max_num_tokens: Max number of tokens for generation servers - gen_enable_attention_dp: Enable attention DP for generation servers - gen_gpu_memory_fraction: GPU memory fraction for generation servers - eplb_num_slots: Number of slots for eplb - worker_start_port: Start port for workers - server_port: Server port - """ - gen_cuda_graph_batch_sizes = [ - 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024, 2048, gen_batch_size - ] - - gen_moe_backend = "CUTLASS" - if gen_tp_size >= 16 and gen_enable_attention_dp: - gen_moe_backend = "WIDEEP" - if not gen_enable_attention_dp: - gen_moe_backend = "TRTLLM" - - config = { - 'model': model_path, - 'hostname': 'localhost', - 'port': server_port, - 'backend': 'pytorch', - 'context_servers': { - 'num_instances': num_ctx_servers, - 'max_batch_size': ctx_batch_size, - 'max_num_tokens': ctx_max_num_tokens, - 'max_seq_len': ctx_max_seq_len, - 'free_gpu_memory_fraction': ctx_free_gpu_memory_fraction, - 'tensor_parallel_size': ctx_tp_size, - 'moe_expert_parallel_size': ctx_tp_size, - 'enable_attention_dp': ctx_enable_attention_dp, - 'pipeline_parallel_size': ctx_pp_size, - 'print_iter_log': True, - 'disable_overlap_scheduler': True, - 'kv_cache_config': { - 'enable_block_reuse': False, - 'free_gpu_memory_fraction': 
ctx_free_gpu_memory_fraction, - 'dtype': 'fp8', - }, - 'cache_transceiver_config': { - 'max_tokens_in_buffer': cache_transceiver_max_num_tokens, - 'backend': 'DEFAULT', - }, - }, - 'generation_servers': { - 'num_instances': num_gen_servers, - 'tensor_parallel_size': gen_tp_size, - 'moe_expert_parallel_size': gen_tp_size, - 'enable_attention_dp': gen_enable_attention_dp, - 'pipeline_parallel_size': gen_pp_size, - 'max_batch_size': gen_batch_size, - 'max_num_tokens': gen_max_num_tokens, - 'max_seq_len': gen_max_seq_len, - 'free_gpu_memory_fraction': gen_gpu_memory_fraction, - 'cuda_graph_config': { - 'enable_padding': True, - 'batch_sizes': gen_cuda_graph_batch_sizes, - }, - 'print_iter_log': True, - 'kv_cache_config': { - 'enable_block_reuse': False, - 'free_gpu_memory_fraction': gen_gpu_memory_fraction, - 'dtype': 'fp8', - }, - 'moe_config': { - 'backend': gen_moe_backend, - }, - 'cache_transceiver_config': { - 'max_tokens_in_buffer': cache_transceiver_max_num_tokens, - 'backend': 'DEFAULT', - }, - 'stream_interval': 20, - } - } - - # Process nodes and generate URLs - max_tasks_per_node, nodes, task_nodes = process_node_and_task() - node_ports = {node: worker_start_port for node in nodes} - - # Generate URLs for context and generation servers - ctx_urls, task_nodes_offset = generate_urls("ctx", num_ctx_servers, - ctx_tp_size, ctx_pp_size, - max_tasks_per_node, nodes, - task_nodes, node_ports) - if num_ctx_servers > 0: - config['context_servers']['urls'] = ctx_urls - - gen_urls, _ = generate_urls("gen", num_gen_servers, gen_tp_size, - gen_pp_size, max_tasks_per_node, nodes, - task_nodes, node_ports, task_nodes_offset) - config['generation_servers']['urls'] = gen_urls - - # set the hostname to the first node - config['hostname'] = nodes[0] - - if gen_tp_size == 8 and not gen_enable_attention_dp: - config['generation_servers']['allreduce_strategy'] = "MNNVL" - - if eplb_num_slots > 0: - moe_load_balancer_file = os.path.join(os.path.dirname(config_path), - "moe_load_balancer.yaml") - moe_load_balancer_config = { - 'num_slots': eplb_num_slots, - 'layer_updates_per_iter': 1 - } - with open(moe_load_balancer_file, "w") as f: - yaml.dump(moe_load_balancer_config, - f, - default_flow_style=False, - sort_keys=False) - config['generation_servers']['moe_config'][ - 'load_balancer'] = moe_load_balancer_file - - if mtp_size > 0: - config['context_servers']['speculative_config'] = { - 'decoding_type': 'MTP', - 'num_nextn_predict_layers': mtp_size - } - config['generation_servers']['speculative_config'] = { - 'decoding_type': 'MTP', - 'num_nextn_predict_layers': mtp_size - } - - # Write config to file - with open(config_path, 'w') as f: - yaml.dump(config, f, default_flow_style=False, sort_keys=False) - - -# gen main and args -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="/tmp/config.yaml") - parser.add_argument("--model", - type=str, - required=True, - help="Path to the model") - parser.add_argument("--num_ctx_servers", - type=int, - required=True, - help="Number of context servers") - parser.add_argument("--ctx_tp_size", - type=int, - required=True, - help="Tensor parallel size for context servers") - parser.add_argument("--ctx_pp_size", - type=int, - default=1, - help="Pipeline parallel size for context servers") - parser.add_argument("--ctx_batch_size", - type=int, - required=True, - help="Batch size for context servers") - parser.add_argument("--ctx_max_num_tokens", - type=int, - required=True, - help="Max number of tokens for 
context servers") - parser.add_argument("--ctx_max_seq_len", - type=int, - required=True, - help="Max sequence length for context servers") - parser.add_argument("--ctx_free_gpu_memory_fraction", - type=float, - required=True, - help="Free GPU memory fraction for context servers") - parser.add_argument("--ctx_enable_attention_dp", - dest='ctx_enable_attention_dp', - action='store_true', - help="Enable attention DP for context servers") - parser.add_argument("--num_gen_servers", - type=int, - required=True, - help="Number of generation servers") - parser.add_argument("--gen_tp_size", - type=int, - required=True, - help="Tensor parallel size for generation servers") - parser.add_argument("--gen_pp_size", - type=int, - default=1, - help="Pipeline parallel size for generation servers") - parser.add_argument("--gen_batch_size", - type=int, - required=True, - help="Batch size for generation servers") - parser.add_argument("--gen_max_num_tokens", - type=int, - required=True, - help="Max number of tokens for generation servers") - parser.add_argument("--gen_max_seq_len", - type=int, - required=True, - help="Max sequence length for generation servers") - parser.add_argument("--gen_enable_attention_dp", - dest='gen_enable_attention_dp', - action='store_true', - help="Enable attention DP for generation servers") - parser.add_argument("--gen_gpu_memory_fraction", - type=float, - required=True, - help="GPU memory fraction for generation servers") - parser.add_argument("--eplb_num_slots", - type=int, - default=0, - help="Number of slots for eplb") - parser.add_argument("--mtp_size", - type=int, - default=0, - help="Number of nextn layers for MTP") - parser.add_argument("--worker_start_port", - type=int, - default=8336, - help="Start port for workers") - parser.add_argument("--server_port", - type=int, - default=8333, - help="Server port") - parser.add_argument("--cache_transceiver_max_num_tokens", - type=int, - default=4608, - help="Max number of tokens for cache transceiver") - - args = parser.parse_args() - - gen_config_file(args.config, args.model, args.num_ctx_servers, - args.ctx_tp_size, args.ctx_pp_size, args.ctx_batch_size, - args.ctx_max_num_tokens, args.ctx_max_seq_len, - args.ctx_free_gpu_memory_fraction, - args.ctx_enable_attention_dp, args.num_gen_servers, - args.gen_tp_size, args.gen_pp_size, args.gen_batch_size, - args.gen_max_num_tokens, args.gen_max_seq_len, - args.gen_enable_attention_dp, args.gen_gpu_memory_fraction, - args.eplb_num_slots, args.mtp_size, args.worker_start_port, - args.server_port, args.cache_transceiver_max_num_tokens) diff --git a/examples/disaggregated/slurm/benchmark/run_benchmark.sh b/examples/disaggregated/slurm/benchmark/run_benchmark.sh index bca7657446c..d821a10b024 100644 --- a/examples/disaggregated/slurm/benchmark/run_benchmark.sh +++ b/examples/disaggregated/slurm/benchmark/run_benchmark.sh @@ -1,8 +1,7 @@ #!/bin/bash - -# Add error handling -set -e set -u +set -e +set -x trap 'echo "Error occurred at line $LINENO"; exit 1' ERR # Add parameter validation @@ -26,10 +25,7 @@ if [[ ${SLURM_PROCID} != "0" ]]; then exit 0 fi -echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}" - -set -x -config_file=${log_path}/config.yaml +config_file=${log_path}/server_config.yaml # check if the config file exists every 10 seconds timeout 1800 seconds timeout=1800 @@ -82,14 +78,26 @@ done # try client do_get_logs(){ - worker_log_path=$1 + log_path=$1 output_folder=$2 - grep -a "'num_ctx_requests': 0, 'num_ctx_tokens': 0" ${worker_log_path} > ${output_folder}/gen_only.txt || true 
- grep -a "'num_generation_tokens': 0" ${worker_log_path} > ${output_folder}/ctx_only.txt || true + + for gen_file in ${log_path}/output_gen_*.log; do + if [ -f "$gen_file" ]; then + index=$(basename "$gen_file" | sed 's/output_gen_\(.*\)\.log/\1/') + grep -a "'num_ctx_requests': 0, 'num_ctx_tokens': 0" "$gen_file" > "${output_folder}/gen_only_${index}.txt" || true + fi + done + + for ctx_file in ${log_path}/output_ctx_*.log; do + if [ -f "$ctx_file" ]; then + index=$(basename "$ctx_file" | sed 's/output_ctx_\(.*\)\.log/\1/') + grep -a "'num_generation_tokens': 0" "$ctx_file" > "${output_folder}/ctx_only_${index}.txt" || true + fi + done } # run the loadgen -cp ${log_path}/output_workers.log ${log_path}/workers_start.log +echo "Starting benchmark..." for concurrency in ${concurrency_list}; do mkdir -p ${log_path}/concurrency_${concurrency} max_count=$((${concurrency} * ${multi_round})) @@ -110,8 +118,7 @@ for concurrency in ${concurrency_list}; do --no-test-input \ $(if [ "${streaming}" = "false" ]; then echo "--non-streaming"; fi) - do_get_logs ${log_path}/output_workers.log ${log_path}/concurrency_${concurrency} - echo "" > ${log_path}/output_workers.log + do_get_logs ${log_path} ${log_path}/concurrency_${concurrency} echo "done for ${concurrency} in folder ${log_path}/concurrency_${concurrency}" done diff --git a/examples/disaggregated/slurm/benchmark/start_server.sh b/examples/disaggregated/slurm/benchmark/start_server.sh index beb4b2f18c8..0f4c498dc1d 100644 --- a/examples/disaggregated/slurm/benchmark/start_server.sh +++ b/examples/disaggregated/slurm/benchmark/start_server.sh @@ -1,34 +1,17 @@ -#! /bin/bash +#!/bin/bash +set -u +set -e +set -x -echo "commit id: $TRT_LLM_GIT_COMMIT" -echo "ucx info: $(ucx_info -v)" -echo "hostname: $(hostname)" +num_ctx_servers=$1 +num_gen_servers=$2 +work_dir=$3 +script_dir=$4 -hostname=$(hostname) -short_hostname=$(echo "$hostname" | awk -F'.' '{print $1}') -echo "short_hostname: ${short_hostname}" +python3 ${script_dir}/gen_server_config.py \ + --num_ctx_servers ${num_ctx_servers} \ + --num_gen_servers ${num_gen_servers} \ + --work_dir ${work_dir} +echo "server config generated to ${work_dir}/server_config.yaml" -config_file=$1 - -# Check and replace hostname settings in config_file -if [ -f "$config_file" ]; then - # Use sed to find hostname line and check if replacement is needed - if grep -q "hostname:" "$config_file"; then - # Extract current hostname value from config - current_hostname=$(grep "hostname:" "$config_file" | sed 's/.*hostname:[ ]*//' | awk '{print $1}') - - if [ "$current_hostname" != "$short_hostname" ]; then - echo "Replacing hostname '$current_hostname' with '$short_hostname' in $config_file" - # Use sed to replace hostname value - sed -i "s/hostname:[ ]*[^ ]*/hostname: $short_hostname/" "$config_file" - else - echo "Hostname '$current_hostname' already matches '$short_hostname', no change needed" - fi - else - echo "No hostname setting found in $config_file" - fi -else - echo "Config file $config_file not found" -fi - -trtllm-serve disaggregated -c ${config_file} -t 1800 -r 7200 +trtllm-serve disaggregated -c ${work_dir}/server_config.yaml -t 7200 -r 7200 diff --git a/examples/disaggregated/slurm/benchmark/start_worker.sh b/examples/disaggregated/slurm/benchmark/start_worker.sh index b10099aa033..efa413ff796 100644 --- a/examples/disaggregated/slurm/benchmark/start_worker.sh +++ b/examples/disaggregated/slurm/benchmark/start_worker.sh @@ -1,16 +1,23 @@ #! 
/bin/bash +set -u +set -e +set -x + +role=$1 +instance_id=$2 +model_path=$3 +port=$4 +benchmark_mode=$5 +concurrency=$6 +enable_pdl=$7 +work_dir=$8 +nsys_folder=${9:-} -config_file=$1 -enable_pdl=$2 -ctx_gpus=$3 -benchmark_mode=$4 -concurrency=$5 -work_dir=$6 unset UCX_TLS -echo "config_file: ${config_file}, enable_pdl: ${enable_pdl}, ctx_gpus: ${ctx_gpus}, work_dir: ${work_dir}" +echo "concurrency: ${concurrency}, enable_pdl: ${enable_pdl}, work_dir: ${work_dir}" +echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname), instance_id: ${instance_id}" export TLLM_LOG_LEVEL=INFO -export TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER=1 if [ "${enable_pdl}" = "true" ]; then export TRTLLM_ENABLE_PDL=1 @@ -21,21 +28,43 @@ if [ "${benchmark_mode}" = "gen_only" ]; then export TLLM_BENCHMARK_REQ_QUEUES_SIZE=${concurrency} fi -#check if work_dir is provided -if [ -z "${work_dir}" ]; then +if [ "${role}" = "CTX" ]; then + config_file=${work_dir}/ctx_config.yaml +elif [ "${role}" = "GEN" ]; then + config_file=${work_dir}/gen_config.yaml +else + echo "Invalid role: ${role}" + exit 1 +fi +echo "config_file: ${config_file}" + +# save the hostname to a file + +# if SLURM_NODEID is 0 +if [ "${SLURM_NODEID}" = "0" ]; then + mkdir -p ${work_dir}/hostnames/ + echo $(hostname) > ${work_dir}/hostnames/${role}_${instance_id}.txt + echo "hostname saved to ${work_dir}/hostnames/${role}_${instance_id}.txt" +fi + +#check if nsys_folder is provided +if [ -z "${nsys_folder:-}" ]; then echo "nsys is not enabled, start normal flow" - trtllm-serve disaggregated_mpi_worker -c ${config_file} + trtllm-llmapi-launch trtllm-serve ${model_path} --host $(hostname) --port ${port} --extra_llm_api_options ${config_file} else nsys_prefix="" - nsys_file=${work_dir}/nsys_worker_proc_${SLURM_PROCID} + nsys_file=${nsys_folder}/nsys_worker_proc_${instance_id}_${SLURM_PROCID} export TLLM_PROFILE_RECORD_GC=1 export TLLM_NVTX_DEBUG=1 - if [ ${SLURM_PROCID} -ge ${ctx_gpus} ]; then + if [ "${role}" = "GEN" ]; then export TLLM_PROFILE_START_STOP=200-250 nsys_prefix="nsys profile -e \"NSYS_MPI_STORE_TEAMS_PER_RANK=1\" -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none" echo "nsys_prefix: ${nsys_prefix}" - else + elif [ "${role}" = "CTX" ]; then echo "nsys is not enabled on ctx_gpus" fi - ${nsys_prefix} trtllm-serve disaggregated_mpi_worker -c ${config_file} + trtllm-llmapi-launch ${nsys_prefix} \ + trtllm-serve ${model_path} \ + --host $(hostname) --port ${port} \ + --extra_llm_api_options ${config_file} fi From af7dde05792eec585e3cf6ebfb8fc1c645f45f14 Mon Sep 17 00:00:00 2001 From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Date: Wed, 27 Aug 2025 03:09:51 -0700 Subject: [PATCH 2/2] Fix style Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> --- .../slurm/benchmark/gen_worker_config.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/disaggregated/slurm/benchmark/gen_worker_config.py b/examples/disaggregated/slurm/benchmark/gen_worker_config.py index adee5cbe72c..c37b8ab78c0 100644 --- a/examples/disaggregated/slurm/benchmark/gen_worker_config.py +++ b/examples/disaggregated/slurm/benchmark/gen_worker_config.py @@ -223,11 +223,12 @@ def gen_config_file(work_dir: str, args = parser.parse_args() - gen_config_file(args.work_dir, args.ctx_tp_size, args.ctx_pp_size, args.ctx_batch_size, - args.ctx_max_num_tokens, args.ctx_max_seq_len, - args.ctx_free_gpu_memory_fraction, - 
args.ctx_enable_attention_dp, args.gen_tp_size, args.gen_pp_size, - args.gen_batch_size, args.gen_max_num_tokens, - args.gen_max_seq_len, args.gen_enable_attention_dp, - args.gen_gpu_memory_fraction, args.eplb_num_slots, - args.mtp_size, args.cache_transceiver_max_num_tokens) + gen_config_file(args.work_dir, args.ctx_tp_size, args.ctx_pp_size, + args.ctx_batch_size, args.ctx_max_num_tokens, + args.ctx_max_seq_len, args.ctx_free_gpu_memory_fraction, + args.ctx_enable_attention_dp, args.gen_tp_size, + args.gen_pp_size, args.gen_batch_size, + args.gen_max_num_tokens, args.gen_max_seq_len, + args.gen_enable_attention_dp, args.gen_gpu_memory_fraction, + args.eplb_num_slots, args.mtp_size, + args.cache_transceiver_max_num_tokens)
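
For a quick sanity check outside of SLURM, the two generator scripts can be exercised on their own. This is only a sketch: the paths, hostnames, and server counts are placeholders, and `gen_server_config.py` blocks until one hostname file per server exists (normally written by `start_worker.sh`), so those files are faked here.

```bash
# Run from examples/disaggregated/slurm/benchmark with PyYAML installed.
mkdir -p /tmp/disagg_logs/hostnames

# Writes ctx_config.yaml and gen_config.yaml into the work directory,
# using the argparse defaults for anything not overridden.
python3 gen_worker_config.py --work_dir /tmp/disagg_logs \
    --ctx_tp_size 4 --gen_tp_size 8

# gen_server_config.py waits for <work_dir>/hostnames/CTX_*.txt and GEN_*.txt,
# so create placeholder files for one context and one generation server.
echo node-001 > /tmp/disagg_logs/hostnames/CTX_0.txt
echo node-002 > /tmp/disagg_logs/hostnames/GEN_0.txt

# Writes server_config.yaml pointing at node-001:8336 and node-002:8336.
python3 gen_server_config.py --num_ctx_servers 1 --num_gen_servers 1 \
    --work_dir /tmp/disagg_logs
```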