
[ROCm][CI] Attempt to fix the failures under a subgroup of the e2e test group #29358

Merged
DarkLight1337 merged 11 commits into vllm-project:main from ROCm:akaratza_async_sched
Dec 10, 2025

Conversation

@AndreasKaratzas
Collaborator

@AndreasKaratzas AndreasKaratzas commented Nov 25, 2025

This PR ensures that during test_async_scheduling we use the TRITON_ATTN backend, which is the default attention backend for ROCm.

Test used to verify functionality on ROCm: pytest -v -s tests/v1/e2e/test_async_scheduling.py
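The backend override described above can be sketched as follows. VLLM_ATTENTION_BACKEND is the environment variable vLLM reads for this; the helper function and its wiring are illustrative, not the PR's exact diff:

```python
import os

def pick_attention_backend(is_rocm: bool) -> str:
    # FLEX_ATTENTION is unsupported on ROCm, where TRITON_ATTN is the
    # default attention backend (illustrative helper, not the PR's code).
    return "TRITON_ATTN" if is_rocm else "FLEX_ATTENTION"

# The test would set this before constructing the engine.
os.environ["VLLM_ATTENTION_BACKEND"] = pick_attention_backend(is_rocm=False)
print(os.environ["VLLM_ATTENTION_BACKEND"])  # FLEX_ATTENTION on non-ROCm
```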

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@mergify mergify bot added rocm Related to AMD ROCm v1 labels Nov 25, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a CI failure on ROCm within test_async_scheduling. The root cause was that the test forced the FLEX_ATTENTION backend, which is unsupported on ROCm, leading to a runtime error. The proposed fix introduces a platform-specific check to set a supported attention backend. For ROCm, it now correctly uses TRITON_ATTN, while for other platforms, it preserves the existing behavior of using FLEX_ATTENTION. The change is well-targeted, correct, and should resolve the CI failure.

@tjtanaa
Collaborator

tjtanaa commented Nov 26, 2025

@AndreasKaratzas

I synced your branch with main and tried running the tests. I am getting 2 failure cases, are they intended to be failing and handled in other PR?

(Worker pid=230405) ERROR 11-26 04:55:53 [multiproc_executor.py:822]   File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 623, in run
(Worker pid=230405) ERROR 11-26 04:55:53 [multiproc_executor.py:822]     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
(Worker pid=230405) ERROR 11-26 04:55:53 [multiproc_executor.py:822]     ^^^^^^^^^^
(Worker pid=230405) ERROR 11-26 04:55:53 [multiproc_executor.py:822]   File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 467, in __getattribute__
(Worker pid=230405) ERROR 11-26 04:55:53 [multiproc_executor.py:822]     self._init_handles()
(Worker pid=230405) ERROR 11-26 04:55:53 [multiproc_executor.py:822]   File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 461, in _init_handles
(Worker pid=230405) ERROR 11-26 04:55:53 [multiproc_executor.py:822]     raise OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads")
(Worker pid=230405) ERROR 11-26 04:55:53 [multiproc_executor.py:822] triton.runtime.errors.OutOfResources: out of resource: threads, Required: 2048, Hardware limit: 1024. Reducing block sizes or `num_stages` may help.
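The OutOfResources error follows from Triton computing a launch's thread count as num_warps × warp_size, as the raise in compiler.py shows. On AMD GPUs the warp (wavefront) size is 64 rather than NVIDIA's 32, so a kernel whose warp count fits NVIDIA's 1024-thread limit can double past it on ROCm. A back-of-envelope sketch (num_warps=32 is an assumption consistent with "Required: 2048" at warp size 64; the failing kernel's actual tuning may differ):

```python
# Sketch of the resource check behind the traceback above.
WARP_SIZE_NV = 32    # NVIDIA warp size
WARP_SIZE_AMD = 64   # AMD wavefront size
HW_THREAD_LIMIT = 1024

num_warps = 32  # assumed tuning value, see lead-in
print(num_warps * WARP_SIZE_NV)   # 1024 -> fits the hardware limit
print(num_warps * WARP_SIZE_AMD)  # 2048 -> exceeds 1024, hence OutOfResources
```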

Summary of the tests run

HIP_VISIBLE_DEVICES=7 pytest -v -s tests/v1/e2e/test_async_scheduling.py

tests/v1/e2e/test_async_scheduling.py:179: AssertionError
====================================================== warnings summary =======================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================================== short test summary info ===================================================
FAILED tests/v1/e2e/test_async_scheduling.py::test_without_spec_decoding - vllm.v1.engine.exceptions.EngineDeadError: EngineCore encountered an issue. See stack trace (above) for the root cause.
FAILED tests/v1/e2e/test_async_scheduling.py::test_with_spec_decoding - AssertionError: assert False
========================================== 2 failed, 2 warnings in 500.32s (0:08:20) ==========================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

… without affecting other platforms

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@AndreasKaratzas
Collaborator Author

I built the upstream dockers (both base and .rocm) and the issue you attached above did not appear on my MI325. However, accuracy issues appeared, along with problems in the bad-words subgroup, so I refactored the solution. Let me know if you still get errors (I ran the test 3 times before uploading; it passed all 3).

For the newly introduced modifications, here are some details:

Platform-specific attention backends

  • ROCm: ROCM_AITER_FA for non-spec-decoding, TRITON_ATTN for spec-decoding
  • Other platforms: FLEX_ATTENTION (unchanged from original)

Platform-specific dtype

  • ROCm: float16 for non-spec-decoding, float32 for spec-decoding (TRITON_ATTN supports higher precision)
  • Other platforms: float32 always (original behavior)

Relaxed tolerances for ROCm

  • Logprob comparison uses rel_tol=5e-2, abs_tol=1e-5 on ROCm since we are using fp16
  • Other platforms use original strict tolerances rel_tol=1e-3, abs_tol=1e-6

Reduced test configurations on ROCm

  • Skip chunk_prefill=True configs (FP variance compounds across chunks)
  • Only run structured_outputs tests (deterministic despite FP differences)

Skip strict logprobs check for ROCm spec-decoding

  • When logprobs are requested during spec-decoding tests on ROCm, skip strict comparison since values can differ slightly even when selected tokens match

I think I'm still respecting the purpose of this particular test with these new changes. I am open to feedback. Even for NVIDIA the only attention backend that works is FLEX_ATTENTION. So the peculiarities of this test as is render almost every attention backend incompatible with it.

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@AndreasKaratzas
Collaborator Author

Btw, commands for test execution:

docker build -f docker/Dockerfile.rocm_base -t rocm/vllm-dev:base . && docker build --no-cache -f docker/Dockerfile.rocm --target test -t vllm-rocm-test:latest .  

docker run --rm -it --device /dev/kfd --device /dev/dri --network=host --shm-size=16gb --group-add video -w /vllm-workspace/tests -e HF_TOKEN=<YOUR_HF_TOKEN> -e PYTHONPATH=/vllm-workspace:$PYTHONPATH -e VLLM_WORKER_MULTIPROC_METHOD=spawn vllm-rocm-test:latest bash -c "pytest -v -s v1/e2e/test_async_scheduling.py"

@AndreasKaratzas AndreasKaratzas changed the title [ROCm][CI] Attempt to fix the failures under the test group [ROCm][CI] Attempt to fix the failures under a subgroup of the e2e the test group Nov 29, 2025
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@mergify mergify bot added the ci/build label Dec 3, 2025

# Data processing
-xgrammar==0.1.27
+xgrammar @ git+https://github.com/divakar-amd/xgrammar@3272f7c520564858056a60480d5afdf69ae79c84
Contributor


Is this intentional? Can you please explain why you are changing this to @divakar-amd's fork?

Collaborator Author

Hi Sage, yes, this is intentional. The problem is that xgrammar has a hard-coded warp size:
divakar-amd/xgrammar@41a849f

So the pip package (or, for that matter, the upstream xgrammar repo) does not work correctly on ROCm. There is an open PR on xgrammar for this:
mlc-ai/xgrammar#476

But as you can see, this PR has been open for more than 2 weeks, so for the test to work for now, we are going with this solution. I am in contact with @divakar-amd, and as soon as his PR gets merged, we will change the requirements to at least be in parity with CUDA.

if current_platform.is_rocm() and not is_testing_with_spec_decoding:
    dtype = "float16"
else:
    dtype = "float32"
Contributor


Why not use float32 for everything? Would you still have to update tolerances for rocm if everything was float32?

Collaborator Author

@AndreasKaratzas AndreasKaratzas Dec 4, 2025


TL;DR: ROCM_AITER_FA, the only backend that can satisfy that particular subgroup of the test on ROCm, does not support fp32.

I think that, even for NVIDIA, the only attention backend accurate enough for this test is FLEX_ATTENTION; it is the only backend that is more numerically accurate than the others. Unfortunately, the FLEX_ATTENTION backend is not fully supported on ROCm yet, so every other backend will inject more numerical error on any platform. Beyond that, our backends, especially the AITER ones, are tested on float16, bfloat16, fp8, and in some cases fp4. So for now, on ROCm, we are going with float16 precision. In the future we will introduce a more accurate backend to better serve fp32 workloads.
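As a back-of-envelope check on why the fp16 run needs relaxed tolerances: half precision carries 10 explicit significand bits, so a single fp16 rounding step already introduces a relative error of about 1e-3, the same magnitude as the original strict rel_tol. A minimal sketch with the tolerances quoted earlier (the two-rounding-step perturbation is an illustrative assumption, not measured test data):

```python
import math

# IEEE 754 half precision: 10 explicit significand bits.
fp16_eps = 2.0 ** -10  # ~9.77e-4 relative error per fp16 rounding

a = 1.0
b = a * (1.0 + 2 * fp16_eps)  # value off by two fp16 rounding steps

# The strict tolerance is below even this tiny perturbation...
print(math.isclose(a, b, rel_tol=1e-3, abs_tol=1e-6))  # False
# ...while the relaxed ROCm tolerance absorbs it comfortably.
print(math.isclose(a, b, rel_tol=5e-2, abs_tol=1e-5))  # True
```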

Contributor

@SageMoore SageMoore left a comment


This is a fine temporary solution, but let's push on mlc-ai/xgrammar#476.

@mergify

mergify bot commented Dec 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @AndreasKaratzas.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 9, 2025
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
@mergify mergify bot removed the needs-rebase label Dec 9, 2025
@mergify mergify bot added the multi-modality Related to multi-modality (#4194) label Dec 10, 2025
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) December 10, 2025 03:33
@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 10, 2025
@DarkLight1337 DarkLight1337 merged commit ed7af31 into vllm-project:main Dec 10, 2025
24 of 25 checks passed
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…e test group (vllm-project#29358)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
@AndreasKaratzas AndreasKaratzas deleted the akaratza_async_sched branch January 5, 2026 16:01
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…e test group (vllm-project#29358)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>

Labels

ci/build multi-modality Related to multi-modality (#4194) ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1
