[CuTe,Fwd,SM90] Enable head dim 512 for SM90#2422

Open
IwakuraRein wants to merge 2 commits into Dao-AILab:main from IwakuraRein:enable-hdim-512-sm90
Conversation

@IwakuraRein

Relax the head-dim checks for SM90; fix the N size of the MMA and the tensor shape of PagedKVManager.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@tridao
Member

tridao commented Apr 1, 2026

For hdim 512 we probably want 2 warp groups where WG0 computes the Q @ K, softmax, then write to smem. Then both warp groups compute P @ V. That's how the FA3 implementation did it.
https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

As written this PR would have 2 WGs both computing the same Q @ K and softmax so there's redundancy here?
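The FA3-style partition described above can be checked with a hedged toy sketch (plain Python on tiny matrices, not the actual CuTe kernel; all names are illustrative): splitting the P @ V GEMM column-wise between two warp groups reproduces the unsplit result, which is why only the softmax output needs to go through smem.

```python
# Hedged sketch of the FA3-style work split: WG0 owns Q @ K^T and softmax,
# then both warp groups compute disjoint column slices of P @ V.
# Tiny nested-list matrices stand in for tiles; not the real CUDA/CuTe code.

def matmul(a, b):
    """Naive matrix multiply on nested lists."""
    inner, cols = len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(len(a))]

def split_cols(m, n_parts):
    """Split a matrix column-wise into n_parts equal slices (one per WG)."""
    width = len(m[0]) // n_parts
    return [[row[p * width:(p + 1) * width] for row in m] for p in range(n_parts)]

P = [[1, 2], [3, 4]]                # attention-probs tile (after WG0's softmax)
V = [[1, 0, 2, 0], [0, 1, 0, 2]]    # V tile, headdim_v = 4

V0, V1 = split_cols(V, 2)           # each warp group gets half of V's columns
out0 = matmul(P, V0)                # WG0's half of the output
out1 = matmul(P, V1)                # WG1's half of the output
stitched = [r0 + r1 for r0, r1 in zip(out0, out1)]
assert stitched == matmul(P, V)     # same result as the unsplit P @ V
```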

@IwakuraRein
Author

@tridao Thanks for the suggestion. Currently it's still using 1 WG, and decode performance is poor. I will optimize based on the LargeHeadDimV path in FA3.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
IwakuraRein added a commit to IwakuraRein/flash-attention that referenced this pull request Apr 2, 2026
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@IwakuraRein
Author

Hi @tridao, thanks again for your reply.

Could you clarify why using 1 warp group for QK is preferred? Is it mainly to reduce register pressure? From my understanding, using 1 warp group for QK and 2 for PV introduces an extra shared-memory round trip. It also requires synchronization after QK, which seems to conflict with ping-pong scheduling. In my tests, using 2 warp groups for both QK and PV didn't cause register spilling.

I'd really appreciate any insight you can share. Thanks!

@tridao
Member

tridao commented Apr 3, 2026

How do you split the work when using 2 WGs to compute QK?
This is a 64 x 512 @ 64 x 512 MMA, right?
Are you splitting the work along the N (64) dimension, or do both WGs compute the same 64 x 64 output?

@IwakuraRein
Author

How do you split the work when using 2 WGs to compute QK? This is a 64 x 512 @ 64 x 512 MMA, right? Are you splitting the work along the N (64) dimension, or do both WGs compute the same 64 x 64 output?

@tridao I am splitting along the N dimension. I think the QK tile is (64, 128, 16) and the PV tile is (64, 512, 16). Therefore, the MNK of the MMA for each warp group is (64, 64, 16) for QK and (64, 256, 16) for PV. The generated SASS confirms this.
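As a sanity check on those numbers, a hedged sketch (plain Python arithmetic; `per_wg_mma_shape` is an illustrative name, not a CuTe API) of splitting a tile's N dimension evenly across two warp groups:

```python
# Illustrative only: split an (M, N, K) MMA tile along N across warp groups.
def per_wg_mma_shape(m, n, k, num_warp_groups=2):
    """Each warp group gets an equal slice of the N dimension."""
    assert n % num_warp_groups == 0, "N must divide evenly across warp groups"
    return (m, n // num_warp_groups, k)

print(per_wg_mma_shape(64, 128, 16))  # QK tile -> (64, 64, 16) per WG
print(per_wg_mma_shape(64, 512, 16))  # PV tile -> (64, 256, 16) per WG
```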

@tridao
Member

tridao commented Apr 3, 2026

What's the M, N, K dimension for the QK MMA?
Shouldn't K be 512 if Q and K have headdim 512?
