Re-implement FlashAttention with new Xe atoms #547
Conversation
I will break up this large commit into self-contained smaller commits after review is complete.
why is this here? This isn't flash attention specific, is it?
No, it's not. These started as some simple helpers to make copying to/from SLM easier for the epilogue. We could move them, maybe to include/cute/algorithm/cute.hpp, though they should be made more sophisticated (use smaller/larger block sizes as appropriate, automatic fallback to scatter/gather, etc.).
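For context, a hypothetical sketch of the simplest form such a helper could take (not the PR's actual code) is a cooperative element-wise copy between two CuTe tensors:

    // Hypothetical sketch, not the PR's code: the simplest form of an SLM copy
    // helper, a cooperative element-wise copy between two CuTe tensors. The real
    // helpers would pick block sizes and fall back to scatter/gather as noted above.
    template <class SrcTensor, class DstTensor>
    CUTE_HOST_DEVICE
    void cooperative_copy_slm(SrcTensor const& src, DstTensor& dst,
                              int work_item, int n_work_items) {
      for (int i = work_item; i < size(src); i += n_work_items) {
        dst(i) = src(i);
      }
    }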
    FragSRow k_rem_mask;
    int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0];
    for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) {
        k_rem_mask(i) = (k < shape<0>(K_2D)) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
    }
If the original S already contains NaN, then fmin(NaN, NaN) = NaN will propagate the NaN into the softmax. This can corrupt the row-wise sum and max, leading to NaN in the final output O. Could we have a better k_rem_mask here to avoid this case?
@ClarkChin08 Can you explain your concern a bit more? If original S has a NaN value in bounds, then that indicates either an overflow from very badly scaled data or an inf/NaN input, and there's no safe way to numerically recover from that (we can't easily guess what the right value should be in place of that NaN). If S has a NaN value out of bounds, then the fmin with -inf will produce -inf, so the NaN will be removed and not corrupt the softmax value.
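For reference, the fmin semantics the mask relies on can be checked with a small host-side snippet (plain std::fmin here; the kernel presumably uses the SYCL equivalent with the same IEEE behavior):

    // Host-side illustration only (not the kernel code). IEEE fmin returns the
    // non-NaN operand when exactly one input is NaN, so NaN acts as an identity
    // mask for in-bounds columns, while -inf forces out-of-bounds columns to -inf
    // before softmax. If S itself already holds NaN, fmin(NaN, NaN) stays NaN,
    // which is exactly the propagation concern raised above.
    #include <cassert>
    #include <cmath>

    int main() {
      float s = 1.25f;
      assert(std::fmin(s, std::nanf("")) == s);                     // in-bounds: score passes through
      assert(std::fmin(s, -INFINITY) == -INFINITY);                 // out-of-bounds: masked to -inf
      assert(std::isnan(std::fmin(std::nanf(""), std::nanf("")))); // NaN already in S propagates
      return 0;
    }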
I agree that if NaN appears in the valid range of S, it's likely a symptom of upstream issues like bad scaling or invalid inputs, and trying to "fix" it in the kernel can be tricky, especially in low-precision formats like fp8/fp4 where overflows are common.
Perhaps adding an optional debug mode to scan for NaNs/invalid inputs in S before softmax could help users identify issues early.
I see, yes, that could be helpful. Perhaps this could take the form of a general helper function that scans for NaNs in an arbitrary tensor and aborts if any are found.
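A rough sketch of what such a helper could look like (hypothetical, not part of this PR):

    // Hypothetical debug-only helper, not part of this PR: walk a CuTe tensor
    // element by element and trap on the first NaN (NaN is the only value for
    // which v != v). Intended to be called on S before the softmax when a
    // debug build or flag is enabled.
    template <class Tensor>
    CUTE_HOST_DEVICE
    void assert_no_nan(Tensor const& t) {
      for (int i = 0; i < size(t); ++i) {
        float v = static_cast<float>(t(i));
        assert(v == v && "NaN detected in tensor");
      }
    }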
The following command encounters an accuracy issue (Disposition: Failed) with seq_len_kv=256. However, when seq_len_kv is changed to 512 or higher, the example passes successfully.
@ClarkChin08 I pushed a patch to fix issues like this earlier today. I double-checked your test case, and it's passing on my system; can you double-check with the latest commit?
Yes, passed now.
Note: the CI is currently failing with compile-time divide-by-zero errors, but I can't reproduce the errors locally with any compiler/compile flags. If anyone can, let me know.
Didn't realize CI was merging branches into main prior to testing. Thanks to @rolandschulz for helping figure this out. Branch is rebased now and split into a logical set of patches.
      auto _0E0 = ScaledBasis<C<0>,0>{};
      auto flayout = filter(flatten(layout));
    - return inner_product_atuple_max(shape(flayout), stride(flayout));
    + auto coshape = inner_product_atuple_max(shape(flayout), stride(flayout)) + _0E0 + _0E0;
why do you add 0 twice?
It's a trick to ensure we get a tuple type for coshape. In case of a trivial layout (_1:_0), inner_product_atuple_max returns a number. Adding 0@0 (_0E0) makes it a ScaledBasis type, and then adding 0@0 again makes a tuple (0). In general, inner_product_atuple_max is already returning a tuple, and then adding 0@0 has no effect.
    const int k_blocks = cute::ceil_div(s.seq_len_kv, get<1>(TileShapeQK{}));

    auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch);
The current layout is HND (batch, num_heads, seq_len, head_size). This format is not used in vLLM or SGLang; instead, the NHD (batch, seq_len, num_heads, head_size) format is used. I would like to provide an example of NHD support, since NHD is the layout that will ultimately be used.
Only a minor tweak is needed for NHD support -- you would keep the shapes and kernel the same, and set the strides appropriately on this line. I added a comment there.
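For illustration, a minimal sketch of that stride tweak, assuming the kernel keeps the (seq_len, head_size, num_heads, batch) logical shape from the line above (the stride_Q_hnd / stride_Q_nhd names are hypothetical):

    // Sketch only: the logical shape stays (seq_len, head_size, num_heads, batch);
    // only the strides encode whether memory is laid out HND or NHD.
    auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch);

    // HND memory order (batch, num_heads, seq_len, head_size), head_size contiguous:
    auto stride_Q_hnd = make_stride(s.head_size_qk,                                  // seq_len
                                    _1{},                                            // head_size
                                    s.seq_len_qo * s.head_size_qk,                   // num_heads
                                    s.num_heads_q * s.seq_len_qo * s.head_size_qk);  // batch

    // NHD memory order (batch, seq_len, num_heads, head_size), as used by vLLM/SGLang:
    auto stride_Q_nhd = make_stride(s.num_heads_q * s.head_size_qk,                  // seq_len
                                    _1{},                                            // head_size
                                    s.head_size_qk,                                  // num_heads
                                    s.seq_len_qo * s.num_heads_q * s.head_size_qk);  // batch

    auto layout_Q_nhd = make_layout(shape_Q, stride_Q_nhd);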
This PR updates FlashAttention to the new copy/MMA atoms.
Changes:
Current status: prefill/decode examples are working, with similar or better performance compared to the old examples.
Known issues:
Additional features (causal masking, variable sequence lengths, etc.) to be added later.
Reminder: the new atoms require a very recent driver due to necessary IGC fixes/enhancements. Recommended version: ci-comp_igc-30613.