diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index e90b72a7cf24..c17784e0a263 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -13,6 +13,9 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2
 # Find the git repository root directory
 GIT_ROOT=$(git rev-parse --show-toplevel)
 
+# Locate the GPU management CLI: prefer nvidia-smi, fall back to rocm-smi.
+SMI_BIN=$(command -v nvidia-smi || command -v rocm-smi)
+
 # Trap the SIGINT signal (triggered by Ctrl+C)
 trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
 
@@ -44,6 +47,15 @@ get_model_args() {
   echo "$extra_args"
 }
 
+# Print the GPU count reported by $SMI_BIN: one line per GPU for nvidia-smi,
+# otherwise the number of `$SMI_BIN -l` output lines that contain "GPU".
+get_num_gpus() {
+  if [[ "$SMI_BIN" == *"nvidia"* ]]; then
+    "$SMI_BIN" --query-gpu=name --format=csv,noheader | wc -l
+  else
+    "$SMI_BIN" -l | grep -c GPU
+  fi
+}
 
 # Function to run tests for a specific model
 run_tests_for_model() {
@@ -64,7 +76,7 @@ run_tests_for_model() {
   # Start prefill instances
   for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
     # Calculate GPU ID - we'll distribute across available GPUs
-    GPU_ID=$((i % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
+    GPU_ID=$((i % $(get_num_gpus)))
     # Calculate port number (base port + instance number)
     PORT=$((8100 + i))
     # Calculate side channel port
@@ -96,7 +108,7 @@ run_tests_for_model() {
   # Start decode instances
   for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
     # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs
-    GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
+    GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(get_num_gpus)))
     # Calculate port number (base port + instance number)
     PORT=$((8200 + i))
     # Calculate side channel port