diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index e87aed027317..4275cfb8d45e 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -1473,7 +1473,7 @@ steps: - tests/v1/kv_connector/nixl_integration/ commands: - uv pip install --system -r /vllm-workspace/requirements/kv_connectors_rocm.txt - - bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh --attention-backend ROCM_ATTN + - ROCM_ATTN=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh - label: DP EP NixlConnector PD accuracy tests (Distributed) # 15min mirror_hardwares: [amdexperimental, amdproduction] @@ -1487,7 +1487,7 @@ steps: - tests/v1/kv_connector/nixl_integration/ commands: - uv pip install --system -r /vllm-workspace/requirements/kv_connectors_rocm.txt - - DP_EP=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh --attention-backend ROCM_ATTN + - DP_EP=1 ROCM_ATTN=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh ##### multi gpus test ##### ##### A100 test ##### diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index a9c2c47eba47..2e25e2f1ac32 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -43,7 +43,12 @@ run_tests() { } # Run tests -run_tests "default backend" "" +if [[ -n "${ROCM_ATTN:-}" ]]; then + echo "ROCM_ATTN is set, running with --attention-backend ROCM_ATTN" + run_tests "ROCM_ATTN backend" "--attention-backend ROCM_ATTN" +else + run_tests "default backend" "" +fi # Check if FLASHINFER is set (non-empty) if [[ -n "${FLASHINFER:-}" ]]; then diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 84a994de584d..0d772242884b 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -287,10 +287,13 @@ def get_attn_backend_cls( return AttentionBackendEnum.ROCM_AITER_FA.get_path() # Priority 3: Check for ROCM_ATTN (prefill-decode split) - from vllm.config import get_current_vllm_config + from vllm.config import get_current_vllm_config_or_none - vllm_config = get_current_vllm_config() - if vllm_config.attention_config.use_prefill_decode_attention: + vllm_config = get_current_vllm_config_or_none() + if ( + vllm_config is not None + and vllm_config.attention_config.use_prefill_decode_attention + ): logger.info("Using Rocm Attention backend.") return AttentionBackendEnum.ROCM_ATTN.get_path()