Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 54 additions & 53 deletions examples/disaggregated/slurm/benchmark/disaggr_torch.slurm
Original file line number Diff line number Diff line change
@@ -1,59 +1,60 @@
#!/bin/bash
# Abort on any failing command (-e), on any reference to an unset variable (-u),
# and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Parse arguments
# The script takes 37 positional arguments. Each variable below is bound to a
# fixed argv index, so the caller must pass every value, in exactly this order.
# Because `set -u` is active, referencing an index that was not supplied aborts
# the script with an "unbound variable" error.
# Indices of 10 and above must be written with braces (${10}, not $10).

# Hardware configuration
gpus_per_node=${1}
numa_bind=${2}
ctx_nodes=${3} # Number of nodes needed for ctx workers
gen_nodes=${4} # Number of nodes needed for gen workers
ctx_world_size=${5} # World size for ctx workers
gen_world_size=${6} # World size for gen workers

# Worker configuration
num_ctx_servers=${7}
ctx_config_path=${8}
num_gen_servers=${9}
gen_config_path=${10}
concurrency_list=${11}

# Sequence and benchmark parameters
isl=${12}
osl=${13}
multi_round=${14}
benchmark_ratio=${15}
streaming=${16}
use_nv_sa_benchmark=${17}
benchmark_mode=${18}
cache_max_tokens=${19}

# Environment and paths
dataset_file=${20}
model_path=${21}
trtllm_repo=${22}
work_dir=${23}
full_logdir=${24}
container_mount=${25}
container_image=${26}
build_wheel=${27}
trtllm_wheel_path=${28}

# Profiling
nsys_on=${29}
ctx_profile_range=${30}
gen_profile_range=${31}

# Accuracy evaluation
enable_accuracy_test=${32}
accuracy_model=${33}
accuracy_tasks=${34}
model_args_extra=${35}

# Worker environment variables
worker_env_var=${36}

# Server environment variables
server_env_var=${37}
# Parse named arguments
#######################################
# Parse `--name value` command-line options into global shell variables.
#
# Every option takes exactly one value (which may be the empty string).
# Options may appear in any order; if an option is repeated, the last
# occurrence wins. The variable for an option is only assigned when that
# option is present on the command line.
#
# Globals:   one variable per option, named after the option with the
#            leading `--` removed and `-` replaced by `_`
#            (e.g. --gpus-per-node -> gpus_per_node).
# Arguments: the script's command-line arguments ("$@").
# Outputs:   an error message on stderr for an unknown option or an option
#            that is missing its value.
# Returns:   does not return on error; exits the shell with status 1.
#######################################
parse_args() {
    local opt val
    while [[ $# -gt 0 ]]; do
        opt="$1"
        # ${2-} avoids an "unbound variable" abort under `set -u` when the
        # value is missing; that case is reported explicitly after the option
        # name itself has been validated.
        val="${2-}"
        case "$opt" in
            # Hardware configuration
            --gpus-per-node) gpus_per_node="$val" ;;
            --numa-bind) numa_bind="$val" ;;
            --ctx-nodes) ctx_nodes="$val" ;;
            --gen-nodes) gen_nodes="$val" ;;
            --ctx-world-size) ctx_world_size="$val" ;;
            --gen-world-size) gen_world_size="$val" ;;
            # Worker configuration
            --num-ctx-servers) num_ctx_servers="$val" ;;
            --ctx-config-path) ctx_config_path="$val" ;;
            --num-gen-servers) num_gen_servers="$val" ;;
            --gen-config-path) gen_config_path="$val" ;;
            --concurrency-list) concurrency_list="$val" ;;
            # Sequence and benchmark parameters
            --isl) isl="$val" ;;
            --osl) osl="$val" ;;
            --multi-round) multi_round="$val" ;;
            --benchmark-ratio) benchmark_ratio="$val" ;;
            --streaming) streaming="$val" ;;
            --use-nv-sa-benchmark) use_nv_sa_benchmark="$val" ;;
            --benchmark-mode) benchmark_mode="$val" ;;
            --cache-max-tokens) cache_max_tokens="$val" ;;
            # Environment and paths
            --dataset-file) dataset_file="$val" ;;
            --model-path) model_path="$val" ;;
            --trtllm-repo) trtllm_repo="$val" ;;
            --work-dir) work_dir="$val" ;;
            --full-logdir) full_logdir="$val" ;;
            --container-mount) container_mount="$val" ;;
            --container-image) container_image="$val" ;;
            --build-wheel) build_wheel="$val" ;;
            --trtllm-wheel-path) trtllm_wheel_path="$val" ;;
            # Profiling
            --nsys-on) nsys_on="$val" ;;
            --ctx-profile-range) ctx_profile_range="$val" ;;
            --gen-profile-range) gen_profile_range="$val" ;;
            # Accuracy evaluation
            --enable-accuracy-test) enable_accuracy_test="$val" ;;
            --accuracy-model) accuracy_model="$val" ;;
            --accuracy-tasks) accuracy_tasks="$val" ;;
            --model-args-extra) model_args_extra="$val" ;;
            # Worker environment variables
            --worker-env-var) worker_env_var="$val" ;;
            # Server environment variables
            --server-env-var) server_env_var="$val" ;;
            *)
                echo "Unknown argument: $opt" >&2
                exit 1
                ;;
        esac
        # A recognised option was the last word on the command line: name it
        # instead of failing inside `shift 2` or on an unbound "$2".
        if [[ $# -lt 2 ]]; then
            echo "Missing value for argument: $opt" >&2
            exit 1
        fi
        shift 2
    done
}

parse_args "$@"

# Print all parsed arguments
echo "Parsed arguments:"
Expand Down
76 changes: 39 additions & 37 deletions examples/disaggregated/slurm/benchmark/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def submit_job(config, log_dir):
save_worker_config(config, gen_config_path, 'gen')

# Prepare sbatch command
# yapf: disable
cmd = [
'sbatch',
f'--partition={slurm_config["partition"]}',
Expand All @@ -163,59 +164,60 @@ def submit_job(config, log_dir):
*([arg for arg in slurm_config['extra_args'].split() if arg]),
slurm_config['script_file'],
# Hardware configuration
str(hw_config['gpus_per_node']),
str(slurm_config['numa_bind']).lower(),
str(ctx_nodes), # Number of nodes needed for ctx workers
str(gen_nodes), # Number of nodes needed for gen workers
str(ctx_world_size), # World size for ctx workers
str(gen_world_size), # World size for gen workers
'--gpus-per-node', str(hw_config['gpus_per_node']),
'--numa-bind', str(slurm_config['numa_bind']).lower(),
'--ctx-nodes', str(ctx_nodes), # Number of nodes needed for ctx workers
'--gen-nodes', str(gen_nodes), # Number of nodes needed for gen workers
'--ctx-world-size', str(ctx_world_size), # World size for ctx workers
'--gen-world-size', str(gen_world_size), # World size for gen workers

# Worker configuration
str(ctx_num),
ctx_config_path,
str(gen_num),
gen_config_path,
config['benchmark']['concurrency_list'],
'--num-ctx-servers', str(ctx_num),
'--ctx-config-path', ctx_config_path,
'--num-gen-servers', str(gen_num),
'--gen-config-path', gen_config_path,
'--concurrency-list', config['benchmark']['concurrency_list'],

# Sequence and benchmark parameters
str(config['benchmark']['input_length']),
str(config['benchmark']['output_length']),
str(config['benchmark']['multi_round']),
str(config['benchmark']['benchmark_ratio']),
str(config['benchmark']['streaming']).lower(),
str(config['benchmark']['use_nv_sa_benchmark']).lower(),
config['benchmark']['mode'],
str(config['worker_config']['gen']['cache_transceiver_config']
'--isl', str(config['benchmark']['input_length']),
'--osl', str(config['benchmark']['output_length']),
'--multi-round', str(config['benchmark']['multi_round']),
'--benchmark-ratio', str(config['benchmark']['benchmark_ratio']),
'--streaming', str(config['benchmark']['streaming']).lower(),
'--use-nv-sa-benchmark', str(config['benchmark']['use_nv_sa_benchmark']).lower(),
'--benchmark-mode', config['benchmark']['mode'],
'--cache-max-tokens', str(config['worker_config']['gen']['cache_transceiver_config']
['max_tokens_in_buffer']),

# Environment and paths
config['benchmark']['dataset_file'],
env_config['model_path'],
env_config['trtllm_repo'],
env_config['work_dir'],
log_dir, # Pass the generated log directory
env_config['container_mount'],
env_config['container_image'],
str(env_config['build_wheel']).lower(),
env_config['trtllm_wheel_path'],
'--dataset-file', config['benchmark']['dataset_file'],
'--model-path', env_config['model_path'],
'--trtllm-repo', env_config['trtllm_repo'],
'--work-dir', env_config['work_dir'],
'--full-logdir', log_dir,
'--container-mount', env_config['container_mount'],
'--container-image', env_config['container_image'],
'--build-wheel', str(env_config['build_wheel']).lower(),
'--trtllm-wheel-path', env_config['trtllm_wheel_path'],

# Profiling
str(profiling_config['nsys_on']).lower(),
profiling_config['ctx_profile_range'],
profiling_config['gen_profile_range'],
'--nsys-on', str(profiling_config['nsys_on']).lower(),
'--ctx-profile-range', profiling_config['ctx_profile_range'],
'--gen-profile-range', profiling_config['gen_profile_range'],

# Accuracy evaluation
str(config['accuracy']['enable_accuracy_test']).lower(),
config['accuracy']['model'],
config['accuracy']['tasks'],
config['accuracy']['model_args_extra'],
'--enable-accuracy-test', str(config['accuracy']['enable_accuracy_test']).lower(),
'--accuracy-model', config['accuracy']['model'],
'--accuracy-tasks', config['accuracy']['tasks'],
'--model-args-extra', config['accuracy']['model_args_extra'],

# Worker environment variables
env_config['worker_env_var'],
'--worker-env-var', env_config['worker_env_var'],

# Server environment variables
env_config['server_env_var']
'--server-env-var', env_config['server_env_var']
]
# yapf: enable

# Submit the job
try:
Expand Down