# Patch summary (commentary only; this text sits ahead of the first `diff --git` header and is not part of any hunk).
#
# examples/disaggregated/slurm/benchmark/disaggr_torch.slurm
#   - Removes the 37 positional assignments (gpus_per_node=${1} ... server_env_var=${37}).
#   - Adds a `while [[ $# -gt 0 ]]` / `case $1` loop: every recognised `--flag` stores "$2" in the
#     same-named variable (dashes become underscores) and then runs `shift 2`; any other argument
#     prints "Unknown argument: $1" and exits with status 1.
#   - NOTE(review): the loop assigns no defaults, so a flag the caller omits leaves its variable
#     unset as far as this hunk shows. Given the `set -euo pipefail` context line, that presumably
#     aborts at first use -- confirm this is the intended failure mode.
#   - NOTE(review): a trailing flag with no value would presumably trip on the unset "$2" rather
#     than reach the `*)` branch -- confirm whether a clearer usage message is wanted.
#
# examples/disaggregated/slurm/benchmark/submit.py (submit_job)
#   - Every value that follows slurm_config['script_file'] in `cmd` is now preceded by its matching
#     `--flag`; the value expressions themselves are unchanged.
#   - Flags whose name differs from the Python-side key or variable:
#       --isl <- config['benchmark']['input_length']
#       --osl <- config['benchmark']['output_length']
#       --benchmark-mode <- config['benchmark']['mode']
#       --cache-max-tokens <- config['worker_config']['gen']['cache_transceiver_config']['max_tokens_in_buffer']
#       --full-logdir <- log_dir
#       --num-ctx-servers <- ctx_num, --num-gen-servers <- gen_num
#       --accuracy-model <- config['accuracy']['model'], --accuracy-tasks <- config['accuracy']['tasks']
#   - The list literal is wrapped in `# yapf: disable` / `# yapf: enable`.
#   - The inline comment "# Pass the generated log directory" on the log_dir entry is dropped.
diff --git a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm index 2fa0ec588b6..2235767fa92 100644 --- a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm +++ b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm @@ -1,59 +1,60 @@ #!/bin/bash set -euo pipefail -# Parse arguments -# Hardware configuration -gpus_per_node=${1} -numa_bind=${2} -ctx_nodes=${3} # Number of nodes needed for ctx workers -gen_nodes=${4} # Number of nodes needed for gen workers -ctx_world_size=${5} # World size for ctx workers -gen_world_size=${6} # World size for gen workers - -# Worker configuration -num_ctx_servers=${7} -ctx_config_path=${8} -num_gen_servers=${9} -gen_config_path=${10} -concurrency_list=${11} - -# Sequence and benchmark parameters -isl=${12} -osl=${13} -multi_round=${14} -benchmark_ratio=${15} -streaming=${16} -use_nv_sa_benchmark=${17} -benchmark_mode=${18} -cache_max_tokens=${19} - -# Environment and paths -dataset_file=${20} -model_path=${21} -trtllm_repo=${22} -work_dir=${23} -full_logdir=${24} -container_mount=${25} -container_image=${26} -build_wheel=${27} -trtllm_wheel_path=${28} - -# Profiling -nsys_on=${29} -ctx_profile_range=${30} -gen_profile_range=${31} - -# Accuracy evaluation -enable_accuracy_test=${32} -accuracy_model=${33} -accuracy_tasks=${34} -model_args_extra=${35} - -# Worker environment variables -worker_env_var=${36} - -# Server environment variables -server_env_var=${37} +# Parse named arguments +while [[ $# -gt 0 ]]; do + case $1 in + # Hardware configuration + --gpus-per-node) gpus_per_node="$2"; shift 2 ;; + --numa-bind) numa_bind="$2"; shift 2 ;; + --ctx-nodes) ctx_nodes="$2"; shift 2 ;; + --gen-nodes) gen_nodes="$2"; shift 2 ;; + --ctx-world-size) ctx_world_size="$2"; shift 2 ;; + --gen-world-size) gen_world_size="$2"; shift 2 ;; + # Worker configuration + --num-ctx-servers) num_ctx_servers="$2"; shift 2 ;; + --ctx-config-path) ctx_config_path="$2"; shift 2 
;; + --num-gen-servers) num_gen_servers="$2"; shift 2 ;; + --gen-config-path) gen_config_path="$2"; shift 2 ;; + --concurrency-list) concurrency_list="$2"; shift 2 ;; + # Sequence and benchmark parameters + --isl) isl="$2"; shift 2 ;; + --osl) osl="$2"; shift 2 ;; + --multi-round) multi_round="$2"; shift 2 ;; + --benchmark-ratio) benchmark_ratio="$2"; shift 2 ;; + --streaming) streaming="$2"; shift 2 ;; + --use-nv-sa-benchmark) use_nv_sa_benchmark="$2"; shift 2 ;; + --benchmark-mode) benchmark_mode="$2"; shift 2 ;; + --cache-max-tokens) cache_max_tokens="$2"; shift 2 ;; + # Environment and paths + --dataset-file) dataset_file="$2"; shift 2 ;; + --model-path) model_path="$2"; shift 2 ;; + --trtllm-repo) trtllm_repo="$2"; shift 2 ;; + --work-dir) work_dir="$2"; shift 2 ;; + --full-logdir) full_logdir="$2"; shift 2 ;; + --container-mount) container_mount="$2"; shift 2 ;; + --container-image) container_image="$2"; shift 2 ;; + --build-wheel) build_wheel="$2"; shift 2 ;; + --trtllm-wheel-path) trtllm_wheel_path="$2"; shift 2 ;; + # Profiling + --nsys-on) nsys_on="$2"; shift 2 ;; + --ctx-profile-range) ctx_profile_range="$2"; shift 2 ;; + --gen-profile-range) gen_profile_range="$2"; shift 2 ;; + # Accuracy evaluation + --enable-accuracy-test) enable_accuracy_test="$2"; shift 2 ;; + --accuracy-model) accuracy_model="$2"; shift 2 ;; + --accuracy-tasks) accuracy_tasks="$2"; shift 2 ;; + --model-args-extra) model_args_extra="$2"; shift 2 ;; + # Worker environment variables + --worker-env-var) worker_env_var="$2"; shift 2 ;; + # Server environment variables + --server-env-var) server_env_var="$2"; shift 2 ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done # Print all parsed arguments echo "Parsed arguments:" diff --git a/examples/disaggregated/slurm/benchmark/submit.py b/examples/disaggregated/slurm/benchmark/submit.py index 9aa00356d2c..12ee15aba35 100644 --- a/examples/disaggregated/slurm/benchmark/submit.py +++ 
b/examples/disaggregated/slurm/benchmark/submit.py @@ -150,6 +150,7 @@ def submit_job(config, log_dir): save_worker_config(config, gen_config_path, 'gen') # Prepare sbatch command + # yapf: disable cmd = [ 'sbatch', f'--partition={slurm_config["partition"]}', @@ -163,59 +164,60 @@ def submit_job(config, log_dir): *([arg for arg in slurm_config['extra_args'].split() if arg]), slurm_config['script_file'], # Hardware configuration - str(hw_config['gpus_per_node']), - str(slurm_config['numa_bind']).lower(), - str(ctx_nodes), # Number of nodes needed for ctx workers - str(gen_nodes), # Number of nodes needed for gen workers - str(ctx_world_size), # World size for ctx workers - str(gen_world_size), # World size for gen workers + '--gpus-per-node', str(hw_config['gpus_per_node']), + '--numa-bind', str(slurm_config['numa_bind']).lower(), + '--ctx-nodes', str(ctx_nodes), # Number of nodes needed for ctx workers + '--gen-nodes', str(gen_nodes), # Number of nodes needed for gen workers + '--ctx-world-size', str(ctx_world_size), # World size for ctx workers + '--gen-world-size', str(gen_world_size), # World size for gen workers # Worker configuration - str(ctx_num), - ctx_config_path, - str(gen_num), - gen_config_path, - config['benchmark']['concurrency_list'], + '--num-ctx-servers', str(ctx_num), + '--ctx-config-path', ctx_config_path, + '--num-gen-servers', str(gen_num), + '--gen-config-path', gen_config_path, + '--concurrency-list', config['benchmark']['concurrency_list'], # Sequence and benchmark parameters - str(config['benchmark']['input_length']), - str(config['benchmark']['output_length']), - str(config['benchmark']['multi_round']), - str(config['benchmark']['benchmark_ratio']), - str(config['benchmark']['streaming']).lower(), - str(config['benchmark']['use_nv_sa_benchmark']).lower(), - config['benchmark']['mode'], - str(config['worker_config']['gen']['cache_transceiver_config'] + '--isl', str(config['benchmark']['input_length']), + '--osl', 
str(config['benchmark']['output_length']), + '--multi-round', str(config['benchmark']['multi_round']), + '--benchmark-ratio', str(config['benchmark']['benchmark_ratio']), + '--streaming', str(config['benchmark']['streaming']).lower(), + '--use-nv-sa-benchmark', str(config['benchmark']['use_nv_sa_benchmark']).lower(), + '--benchmark-mode', config['benchmark']['mode'], + '--cache-max-tokens', str(config['worker_config']['gen']['cache_transceiver_config'] ['max_tokens_in_buffer']), # Environment and paths - config['benchmark']['dataset_file'], - env_config['model_path'], - env_config['trtllm_repo'], - env_config['work_dir'], - log_dir, # Pass the generated log directory - env_config['container_mount'], - env_config['container_image'], - str(env_config['build_wheel']).lower(), - env_config['trtllm_wheel_path'], + '--dataset-file', config['benchmark']['dataset_file'], + '--model-path', env_config['model_path'], + '--trtllm-repo', env_config['trtllm_repo'], + '--work-dir', env_config['work_dir'], + '--full-logdir', log_dir, + '--container-mount', env_config['container_mount'], + '--container-image', env_config['container_image'], + '--build-wheel', str(env_config['build_wheel']).lower(), + '--trtllm-wheel-path', env_config['trtllm_wheel_path'], # Profiling - str(profiling_config['nsys_on']).lower(), - profiling_config['ctx_profile_range'], - profiling_config['gen_profile_range'], + '--nsys-on', str(profiling_config['nsys_on']).lower(), + '--ctx-profile-range', profiling_config['ctx_profile_range'], + '--gen-profile-range', profiling_config['gen_profile_range'], # Accuracy evaluation - str(config['accuracy']['enable_accuracy_test']).lower(), - config['accuracy']['model'], - config['accuracy']['tasks'], - config['accuracy']['model_args_extra'], + '--enable-accuracy-test', str(config['accuracy']['enable_accuracy_test']).lower(), + '--accuracy-model', config['accuracy']['model'], + '--accuracy-tasks', config['accuracy']['tasks'], + '--model-args-extra', 
config['accuracy']['model_args_extra'], # Worker environment variables - env_config['worker_env_var'], + '--worker-env-var', env_config['worker_env_var'], # Server environment variables - env_config['server_env_var'] + '--server-env-var', env_config['server_env_var'] ] + # yapf: enable # Submit the job try: