Skip to content

Conversation

@petercad
Copy link

@petercad petercad commented Oct 4, 2025

This PR updates FlashAttention to the new copy/MMA atoms.

Changes:

  • Prefill and decode unified into a single implementation, allowing simultaneous K and Q subgroup-level parallelization rather than an either-or.
  • GEMMs and softmax grouped together and the full k loop consolidated into an FMHA mainloop class.
    • This will facilitate further manual pipelining/overlap of GEMM with softmax.
  • Use new copy/MMA atoms and reorders to transparently support arbitrary data types.
  • Automatic copy/MMA operator selection.

Current status: prefill/decode examples working, similar/better performance to old examples.

Known issues:

  • Head size 192 decode config doesn't compile yet. Edit: fixed.
  • Strange SYCL compiler behavior/bug with tSrS->tArP reorder. Apparently the compiler believes there is UB somewhere and will omit a large section of the kernel as a result. For the moment, there's a direct copy as a workaround while I pin down the issue. I'm not able to reproduce this behavior with the reorder in isolation.

Additional features (causal masking, variable sequence lengths, etc.) to be added later.

Reminder: the new atoms require a very recent driver due to necessary IGC fixes/enhancements. Recommended version: ci-comp_igc-30613.

@petercad petercad changed the title [Umbrella commit] Re-implement FlashAttention with new Xe atoms Re-implement FlashAttention with new Xe atoms Oct 4, 2025
@petercad
Copy link
Author

petercad commented Oct 4, 2025

I will break up this large commit into self-contained smaller commits after review is complete.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this here? This isn't flash attention specific, is it?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's not. These started as some simple helpers to make copying to/from SLM easier for the epilogue. We could move them, maybe to include/cute/algorithm/cute.hpp, though they should be made more sophisticated (use smaller/larger block sizes as appropriate, automatic fallback to scatter/gather, etc.).

FragSRow k_rem_mask;
int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0];
for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) {
k_rem_mask(i) = (k < shape<0>(K_2D)) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the original S already contains NaN , fmin(NaN, NaN) = NaN, will propagates the NaN to softmax. This can corrupt row-wise sum and max leading to NaN in the final output O, could we have better k_rem_mask here to avoid this case?

Copy link
Author

@petercad petercad Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ClarkChin08 Can you explain your concern a bit more? If original S has a NaN value in bounds, then that indicates either an overflow from very badly scaled data or an inf/NaN input, and there's no safe way to numerically recover from that (we can't easily guess what the right value should be in place of that NaN). If S has a NaN value out of bounds, then the fmin with -inf will produce -inf, so the NaN will be removed and not corrupt the softmax value.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that if NaN appears in the valid range of S, it's likely a symptom of upstream issues like bad scaling or invalid inputs, and trying to "fix" it in the kernel can be tricky, especially in low-precision formats like fp8/fp4 where overflows are common.
Perhaps adding an optional debug mode to scan for NaNs/invalid inputs in S before softmax could help users identify issues early.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, yes, that could be helpful. Perhaps this could take the form of a general helper function that scans for NaNs in an arbitrary tensor and aborts if any are found.

@ClarkChin08
Copy link

ClarkChin08 commented Oct 23, 2025

The following command encounters an accuracy issue (Disposition: Failed) with seq_len_kv=256

output [991]: 2.696791 vs -nan

./examples/06_bmg_flash_attention/06_xe_fmha_fwd_decode_hdim128 --iterations=10 --batch=1 --num_heads_q=8 --seq_len_kv=256 --seq_len_qo=1 --num_heads_kv=8

However, when seq_len_kv is changed to 512 or higher, the example passes successfully.

@petercad
Copy link
Author

petercad commented Oct 23, 2025

The following command encounters an accuracy issue (Disposition: Failed) with seq_len_kv=256

@ClarkChin08 I pushed a patch to fix issues like this earlier today. I double-checked your test case, and it's passing on my system; can you double-check with the latest commit?

@petercad petercad force-pushed the petercad/rearch_sdpa branch from af2f402 to 326669e Compare October 23, 2025 03:54
@ClarkChin08
Copy link

The following command encounters an accuracy issue (Disposition: Failed) with seq_len_kv=256

@ClarkChin08 I pushed a patch to fix issues like this earlier today. I double-checked your test case, and it's passing on my system; can you double-check with the latest commit?

Yes, passed now.

@petercad
Copy link
Author

Note: the CI is currently failing with compile-time divide-by-zero errors, but I can't reproduce the errors locally with any compiler/compile flags. If anyone can, let me know.

@petercad petercad force-pushed the petercad/rearch_sdpa branch from f767eb5 to 10b0c97 Compare October 27, 2025 21:56
@petercad
Copy link
Author

Note: the CI is currently failing with compile-time divide-by-zero errors, but I can't reproduce the errors locally with any compiler/compile flags. If anyone can, let me know.

Didn't realize CI was merging branches into main prior to testing. Thanks to @rolandschulz for helping figure this out.

Branch is rebased now and split into a logical set of patches.

@petercad petercad force-pushed the petercad/rearch_sdpa branch 2 times, most recently from b0e30f4 to 7dd479b Compare October 27, 2025 23:19
@tdeng5 tdeng5 added the release label Oct 28, 2025
@petercad petercad force-pushed the petercad/rearch_sdpa branch from 7dd479b to 460d34a Compare October 28, 2025 15:37
auto _0E0 = ScaledBasis<C<0>,0>{};
auto flayout = filter(flatten(layout));
return inner_product_atuple_max(shape(flayout), stride(flayout));
auto coshape = inner_product_atuple_max(shape(flayout), stride(flayout)) + _0E0 + _0E0;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you add 0 twice?

Copy link
Author

@petercad petercad Oct 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a trick to ensure we get a tuple type for coshape. In case of a trivial layout (_1:_0), inner_product_atuple_max returns a number. Adding 0@0 (_0E0) makes it a ScaledBasis type, and then adding 0@0 again makes a tuple (0). In general, inner_product_atuple_max is already returning a tuple, and then adding 0@0 has no effect.


const int k_blocks = cute::ceil_div(s.seq_len_kv, get<1>(TileShapeQK{}));

auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch);
Copy link

@sunjiweiswift sunjiweiswift Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current layout is HND{batch, num_head, seq_len, head_size}. This format is not used in vllm and sglang; instead, NHD's(batch, seq_len, num_head, head_size) format is used.
I would like to provide an example of NHD support. Because NHD is the final layout used.

Copy link
Author

@petercad petercad Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only a minor tweak is needed for NHD support -- you would keep the shapes and kernel the same, and set the strides appropriately on this line. I added a comment there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request release urgent PR requires a urgent attention (for release or blocking another PR)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants