
[ROCm][CI] Reduce Flakiness For test_async_scheduling Using ROCM_ATTN With FP32 #30811

Merged
tjtanaa merged 3 commits into vllm-project:main from ROCm:micah/async-scheduling-fix
Dec 18, 2025
Conversation

@micah-wil micah-wil commented Dec 16, 2025

We've seen some flakiness in v1/e2e/test_async_scheduling.py::test_without_spec_decoding ever since the test was initially fixed in #29358. Ultimately we believe it was due to switching to fp16 for the test. Before, the test failed about 10%-20% of the time with the following error:

=========================== short test summary info ============================
FAILED v1/e2e/test_async_scheduling.py::test_without_spec_decoding - assert False
 +  where False = _all_logprobs_match([[{4913: Logprob(logprob=-0.7813263535499573, rank=1, decoded_token='{"') ......
================== 1 failed, 3 warnings in 158.04s (0:02:38) ===================

After switching back to dtype=float32 (which was the original setting before #29358, so to be clear this is not new), and using 'ROCM_ATTN', the test passes consistently and I believe the flakiness is addressed.

The test passed 35 times in a row locally with this PR, which I think is convincing enough that the flakiness is resolved.
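For readers unfamiliar with the failing assertion: the check compares per-token logprobs between a reference run and an async-scheduling run. A hypothetical sketch of that kind of comparison (the real helper is _all_logprobs_match in the vLLM test suite; the names and the tol parameter here are illustrative):

```python
def all_logprobs_match(ref_outputs, test_outputs, tol=0.0):
    """Illustrative stand-in for the test's logprob comparison.

    Each element maps token_id -> logprob for one generated position;
    with tol=0 an exact match is required, which is why fp16
    run-to-run variance makes the assertion flaky.
    """
    for ref, test in zip(ref_outputs, test_outputs):
        if set(ref) != set(test):  # different candidate tokens entirely
            return False
        if any(abs(ref[t] - test[t]) > tol for t in ref):
            return False
    return True
```

With an exact-match requirement, even a last-ulp difference in one logprob fails the test, so reducing numerical variance (fp32) matters more than raising a tolerance.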

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
@mergify mergify bot added the rocm (Related to AMD ROCm) and v1 labels Dec 16, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses flakiness in the test_async_scheduling E2E test on ROCm by switching the data type from float16 to float32 and updating the attention backend accordingly. The changes are logical and seem to effectively resolve the numerical precision issues that were causing the test to fail. I have one comment regarding an inconsistency in the supported data types metadata for the ROCM_ATTN backend that this change revealed. Overall, this is a good fix for improving CI stability.

  m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
else:
-  m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
+  m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
gemini-code-assist bot commented (severity: high):

This change to ROCM_ATTN is necessary because the test now uses float32, which is not supported by ROCM_AITER_FA. However, I've found that ROCM_ATTN's backend class, RocmAttentionBackend, also doesn't declare support for float32 in its supported_dtypes list in vllm/v1/attention/backends/rocm_attn.py:

# vllm/v1/attention/backends/rocm_attn.py:155
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

Since this test now passes, it implies that ROCM_ATTN does support float32. To maintain consistency and prevent potential issues, the supported_dtypes list for RocmAttentionBackend should be updated to include torch.float32. This could be addressed in a follow-up.

micah-wil (Contributor, Author):

Good point. Looking into this

micah-wil (Contributor, Author):

Took the suggestion here, hopefully there is no problem with this as it does appear that fp32 works fine with this RocmAttentionBackend.

tjtanaa (Collaborator):

This is an interesting catch. Can you pick a model and share the lm_eval score with this ROCM_ATTN backend and dtype=float32 ? And do you know who from AMD could help to quickly take a look at the Backend and understand if float32 is indeed compatible with ROCM_ATTN?

micah-wil (Contributor, Author):

@tjtanaa Good suggestion, here are some lm_eval scores for Llama 3.1 8B, Mixtral 8x7B, and Mixtral 8x22B:

Test 1: ROCM_ATTN, float32, Llama 3.1 8B
lm_eval --model vllm \
    --model_args "pretrained=meta-llama/Llama-3.1-8B-Instruct,dtype=float32,attention_backend=ROCM_ATTN" \
    --tasks gsm8k --batch_size auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7779|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7566|±  |0.0118|


Test 2: ROCM_ATTN, float16, Llama 3.1 8B
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7763|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.7544|±  |0.0119|

Test 3: ROCM_ATTN, float32, Mixtral-8x7B-Instruct-v0.1
lm_eval --model vllm \
    --model_args "pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,dtype=float32,attention_backend=ROCM_ATTN,tensor_parallel_size=2" \
    --tasks gsm8k --batch_size auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6505|±  |0.0131|
|     |       |strict-match    |     5|exact_match|↑  |0.6459|±  |0.0132|

Test 4: ROCM_ATTN, float16, Mixtral-8x7B-Instruct-v0.1
lm_eval --model vllm \
    --model_args "pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,dtype=float16,attention_backend=ROCM_ATTN,tensor_parallel_size=2" \
    --tasks gsm8k --batch_size auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6505|±  |0.0131|
|     |       |strict-match    |     5|exact_match|↑  |0.6475|±  |0.0132|


Test 5: ROCM_ATTN, float32, Mixtral-8x22B-v0.1
lm_eval --model vllm \
    --model_args "pretrained=mistralai/Mixtral-8x22B-v0.1,dtype=float32,attention_backend=ROCM_ATTN,tensor_parallel_size=4,tokenizer_mode=hf" \
    --tasks gsm8k --batch_size auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7597|±  |0.0118|
|     |       |strict-match    |     5|exact_match|↑  |0.7574|±  |0.0118|

Test 6: ROCM_ATTN, float16, Mixtral-8x22B-v0.1
lm_eval --model vllm \
    --model_args "pretrained=mistralai/Mixtral-8x22B-v0.1,dtype=float16,attention_backend=ROCM_ATTN,tensor_parallel_size=4,tokenizer_mode=hf" \
    --tasks gsm8k --batch_size auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7604|±  |0.0118|
|     |       |strict-match    |     5|exact_match|↑  |0.7589|±  |0.0118|

There seems to be little difference between dtype=float32 and dtype=float16. I am not sure there should be much of a difference: the weights are just being cast from bfloat16 to float32. float32 does seem to help with determinism, though (the lack of which is what ultimately caused this test to fail). I will look closely at the ROCM_ATTN implementation to see whether it supports float32, and if it does, why it was not marked as such.

micah-wil (Contributor, Author) commented Dec 17, 2025:

@tjtanaa Ok, I have a better understanding of what's going on with ROCM_ATTN now (with some help from Greg). There is kernel dispatch logic in chunked_prefill_paged_decode which uses use_rocm_custom_paged_attention to decide whether to dispatch to ops.paged_attention_rocm for decoding. If it cannot (i.e. if use_rocm_custom_paged_attention evaluates to False), it falls back to a Triton kernel, kernel_paged_attention_2d. Inside use_rocm_custom_paged_attention, the returned conditional depends on qtype == torch.half or qtype == torch.bfloat16. So when using float32, use_rocm_custom_paged_attention evaluates to False and decoding falls back to kernel_paged_attention_2d. That Triton kernel handles float32 just fine.

So the TLDR is that ROCM_ATTN does support float32. It was likely just an oversight that it wasn't marked as such before, the supported dtypes were added a bit after the transition to V1. I visually inspected some outputs as well to make sure that things look alright too.
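The dispatch described above can be sketched as follows. This is a simplified illustration with dtypes as plain strings rather than torch.dtype values; the real use_rocm_custom_paged_attention in vLLM also checks other conditions (head size, block size, GPU support, etc.), but the dtype gate is the one relevant here:

```python
def use_rocm_custom_paged_attention(qtype: str) -> bool:
    # Only fp16/bf16 queries are eligible for the custom ROCm kernel
    # (sketch of the dtype condition described in the comment above).
    return qtype in ("float16", "bfloat16")

def chunked_prefill_paged_decode(qtype: str) -> str:
    # Returns the name of the kernel that would handle decode.
    if use_rocm_custom_paged_attention(qtype):
        return "ops.paged_attention_rocm"  # custom ROCm paged attention
    return "kernel_paged_attention_2d"     # Triton fallback; handles fp32
```

With qtype "float32" the gate returns False, so decode falls through to the Triton kernel, which is why ROCM_ATTN works with fp32 despite the custom kernel not supporting it.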

I appreciate your feedback, please let me know if you have any more questions.

micah-wil (Contributor, Author):

Also see the discussion here for more context on this issue:
#30576 (comment)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
@tjtanaa tjtanaa added the ready label (ONLY add when PR is ready to merge/full CI is needed) Dec 17, 2025

mergify bot commented Dec 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @micah-wil.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 17, 2025
…g-fix

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
@mergify mergify bot removed the needs-rebase label Dec 17, 2025

njhill commented Dec 17, 2025

Hey, yes, for context: due to precision-related floating-point op differences with batching, variance is expected in outputs with fp16, and even more so with bf16, but it's much smaller with fp32, such that we can rely on deterministic results.

We need to enforce some kind of batch invariance with ground-truth based tests like this one.
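A small illustration of this point (plain NumPy, not vLLM code): performing the same two additions in a different order changes the fp16 result, while the fp32 results agree to well within any reasonable tolerance:

```python
import numpy as np

# b16 is roughly half an ulp of 1.0 in fp16, so rounding direction matters.
b16 = np.float16(5e-4)
left16 = (np.float16(1.0) + b16) + b16   # rounds up at each step
right16 = np.float16(1.0) + (b16 + b16)  # one smaller rounding step
print(left16 == right16)                 # False: order changes the fp16 result

b32 = np.float32(5e-4)
left32 = (np.float32(1.0) + b32) + b32
right32 = np.float32(1.0) + (b32 + b32)
print(abs(float(left32) - float(right32)) < 1e-6)  # True: fp32 variance is tiny
```

Batching changes the order in which such reductions happen, which is why fp16 outputs drift between runs while fp32 stays effectively deterministic.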

micah-wil (Contributor, Author):

Hi @njhill, that's very helpful & makes sense to me. Thanks for the insight. Does this PR make sense/address the issue then?

njhill (Member) left a comment:

Thanks @micah-wil @tjtanaa!

Could we also revert the rest of the changes made in #29358 now (is_testing_with_spec_decoding etc.) and just use ROCM_ATTN rather than TRITON_ATTN in all cases?

Looks like this is a much better fix for what that was attempting to fix...

micah-wil (Contributor, Author):

@njhill I had considered that suggestion and I wish we could. However, if we do that, we get an error like the following for the test_with_spec_decoding case:

RuntimeError: Worker failed with error 'Unsupported attention metadata type for speculative decoding with num_speculative_tokens > 1: <class 'vllm.v1.attention.backends.rocm_attn.RocmAttentionMetadata'>. Supported types are: (<class 'vllm.v1.attention.backends.triton_attn.TritonAttentionMetadata'>, <class 'vllm.v1.attention.backends.flash_attn.FlashAttentionMetadata'>, <class 'vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionMetadata'>, <class 'vllm.v1.attention.backends.mla.common.MLACommonMetadata'>)', please check the stack trace above for the root cause

I am not sure why yet. Do you think I could address that in a follow up or does it need to be addressed in this PR?

njhill (Member) left a comment:

Ah ok, thanks ... sounds good to leave to possible follow-on then.


tjtanaa commented Dec 18, 2025

> [quoting @micah-wil's comment above regarding the spec-decoding RuntimeError and whether it can be addressed in a follow-up]

@micah-wil Yes, we need to follow up on this. ROCM_ATTN is a new attention backend that was split off from TRITON_ATTN, so we need to update the speculative-decoding code (e.g. eagle.py) to include the ROCM_ATTN metadata class.
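The check that raises the error above is effectively a whitelist of attention-metadata classes. A hypothetical sketch of that pattern (class names are taken from the error message; the function name and string-based comparison are illustrative, not vLLM's actual implementation, and the follow-up would amount to adding RocmAttentionMetadata to the supported set):

```python
# Hypothetical sketch, not the actual vLLM code.
SUPPORTED_SPEC_DECODE_METADATA = (
    "TritonAttentionMetadata",
    "FlashAttentionMetadata",
    "AiterFlashAttentionMetadata",
    "MLACommonMetadata",
)

def check_spec_decode_support(metadata_cls_name: str,
                              num_speculative_tokens: int) -> None:
    # The follow-up fix would add "RocmAttentionMetadata" to the tuple above.
    if (num_speculative_tokens > 1
            and metadata_cls_name not in SUPPORTED_SPEC_DECODE_METADATA):
        raise RuntimeError(
            "Unsupported attention metadata type for speculative decoding "
            f"with num_speculative_tokens > 1: {metadata_cls_name}"
        )
```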

@tjtanaa tjtanaa merged commit fd8afdf into vllm-project:main Dec 18, 2025
49 of 51 checks passed
micah-wil (Contributor, Author):

@tjtanaa Ahh, I see. I will look into that when I get time. Thanks for mentioning it; it's good to have that extra context.

yugong333 pushed a commit to yugong333/vllm that referenced this pull request Dec 22, 2025
… With FP32 (vllm-project#30811)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
… With FP32 (vllm-project#30811)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
… With FP32 (vllm-project#30811)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
… With FP32 (vllm-project#30811)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1
