[CuTe,Fwd,SM90] Enable head dim 512 for SM90 #2422

IwakuraRein wants to merge 2 commits into Dao-AILab:main from
Conversation
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
For hdim 512 we probably want 2 warp groups, where WG0 computes the Q @ K and softmax, then writes to smem. Then both warp groups compute P @ V. That's how the FA3 implementation did it. As written, this PR would have 2 WGs both computing the same Q @ K and softmax, so there's redundancy here?
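The dataflow described above can be sketched numerically. This is a NumPy illustration of the work split only, not the actual CuTe kernel; the function name, shapes, and the "both WGs split P @ V along the head dimension" detail are assumptions chosen for the sketch:

```python
import numpy as np

def fwd_split(Q, K, V):
    """Sketch of the 2-warp-group split: WG0 does Q @ K^T and softmax,
    then both 'warp groups' each take half of P @ V along head dim."""
    # WG0: attention scores + softmax (result P stands in for the smem write)
    S = Q @ K.T
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    # Both WGs: P @ V, splitting the d = 512 head dimension in half
    d = V.shape[1]
    O0 = P @ V[:, : d // 2]   # warp group 0's half of the output
    O1 = P @ V[:, d // 2 :]   # warp group 1's half of the output
    return np.concatenate([O0, O1], axis=1)
```

The point of the split is that only WG0 pays for the softmax, while the large (hdim 512) P @ V GEMM is shared across both warp groups.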
@tridao Thanks for the suggestion. Currently it's still using 1 WG, and the performance in decode is bad. Will optimize based on the
Hi @tridao, thanks again for your reply. Could you clarify why using 1 warp group for QK is preferred? Is it mainly to reduce register pressure? From my understanding, using 1 warp group for QK and 2 for PV introduces an extra shared memory round trip. It also requires synchronization after QK, which seems to conflict with ping-pong scheduling. In my tests, using 2 warp groups for both QK and PV didn't cause register spilling. I'd really appreciate any insight you can share. Thanks!
How do you split the work when using 2 WGs to compute QK? What's the M, N, K dimension for the QK MMA?

@tridao I am splitting along the N dimension. The full QK MMA tile is (64, 128, 16), and the PV tile is (64, 512, 16). Therefore, the per-warpgroup MMA shape is (64, 64, 16) for QK and (64, 256, 16) for PV. The generated SASS confirms this.
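The shapes quoted above follow directly from halving N across the two warp groups. A minimal sketch of that arithmetic (the helper name and the warpgroup count are illustrative, not from the PR's code):

```python
def split_n(mnk, num_warpgroups=2):
    """Split an (M, N, K) MMA tile along N across warp groups."""
    m, n, k = mnk
    assert n % num_warpgroups == 0, "N must divide evenly across warp groups"
    return (m, n // num_warpgroups, k)

qk_per_wg = split_n((64, 128, 16))  # -> (64, 64, 16), per-WG QK MMA shape
pv_per_wg = split_n((64, 512, 16))  # -> (64, 256, 16), per-WG PV MMA shape
```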
Relax the checks for head dim for SM90; fix the N size of the MMA and the tensor shape of PagedKVManager.