[ROCm][Perf] Enabled FP4Indexer for DSV4 #42908
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Code Review
This pull request introduces ROCm support for MXFP4 quantization within the DeepSeek V4 sparse indexer, adding specialized Triton kernels for paged MQA logits and implementing optimizations for trivial top-k scenarios. Feedback from the review identified high-severity issues in the new FP4 MQA kernels, specifically regarding shape mismatches in tl.dot_scaled operations that necessitate transposing the RHS scale tensor.
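The summary does not spell out what a "trivial top-k scenario" is. One plausible reading, which is an assumption here and not taken from the PR's code, is the case where a query has no more candidate tokens than `k`, so every candidate is selected and no sorting is needed. A minimal NumPy sketch of that shortcut follows; `topk_indices` and the `-1` padding value are hypothetical names and conventions chosen for illustration.

```python
import numpy as np


def topk_indices(scores: np.ndarray, k: int) -> np.ndarray:
    """Return k indices into a 1-D score vector (hypothetical helper).

    Assumption: when there are at most k candidates, the result is simply
    every index, padded with -1 up to length k.
    """
    n = scores.shape[0]
    if n <= k:
        # Trivial case: all candidates are kept, so no selection work is needed.
        out = np.full(k, -1, dtype=np.int64)
        out[:n] = np.arange(n)
        return out
    # General case: indices of the k largest scores, in no particular order.
    return np.argpartition(-scores, k - 1)[:k]
```

In the trivial branch the scores are never read, which is the kind of work such a shortcut would avoid.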
    scores = tl.dot_scaled(
        q_packed,
        q_scale,
        "e2m1",
        k_packed,
        k_scale,
        "e2m1",
        lhs_k_pack=True,
        rhs_k_pack=True,
        out_dtype=tl.float32,
    )
The tl.dot_scaled operation expects the RHS scale tensor to have a shape of (K_scaled, N) when rhs_k_pack=True. In this kernel, k_scale is loaded with shape (BLOCK_KV, 4), which corresponds to (N, K_scaled). This mismatch will likely lead to incorrect results or compilation errors. You should transpose k_scale before passing it to tl.dot_scaled.
Suggested change:

      scores = tl.dot_scaled(
          q_packed,
          q_scale,
          "e2m1",
          k_packed,
    -     k_scale,
    +     tl.trans(k_scale),
          "e2m1",
          lhs_k_pack=True,
          rhs_k_pack=True,
          out_dtype=tl.float32,
      )
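A shape concern like this one can be settled empirically by comparing the kernel's output against a host-side reference. The sketch below is not part of the PR: it dequantizes MXFP4 operands in NumPy and computes the scores in float32, assuming one E8M0 scale per 32 elements (so a head dimension of 128 gives 4 scales per row, matching the `(BLOCK_KV, 4)` shape mentioned above) and assuming the even-indexed element sits in the low nibble of each packed byte. The function names are hypothetical, and the packing order should be checked against the actual quantizer.

```python
import numpy as np

# E2M1 magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign.
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)
GROUP = 32  # MX block size: one shared scale per 32 elements


def dequant_mxfp4(packed: np.ndarray, scale_e8m0: np.ndarray) -> np.ndarray:
    """Dequantize rows of packed FP4 values (hypothetical reference helper).

    packed:     (rows, K // 2) uint8, two 4-bit codes per byte
    scale_e8m0: (rows, K // GROUP) uint8, biased power-of-two exponents
    """
    lo = packed & 0x0F  # assumption: even element stored in the low nibble
    hi = packed >> 4
    codes = np.stack([lo, hi], axis=-1).reshape(packed.shape[0], -1)  # (rows, K)
    sign = np.where(codes & 0x8, -1.0, 1.0).astype(np.float32)
    values = sign * E2M1_VALUES[codes & 0x7]
    # E8M0 encodes 2 ** (e - 127); the NaN encoding 0xFF is not handled here.
    scales = np.exp2(scale_e8m0.astype(np.float32) - 127.0)
    return values * np.repeat(scales, GROUP, axis=1)


def reference_scores(q_packed, q_scale, k_packed, k_scale) -> np.ndarray:
    """Float32 reference for the Q.K^T scores, with scales stored per row."""
    q = dequant_mxfp4(q_packed, q_scale)  # (M, K)
    k = dequant_mxfp4(k_packed, k_scale)  # (N, K)
    return q @ k.T                        # (M, N)
```

Running the Triton kernel with and without `tl.trans(k_scale)` on random inputs whose scales differ across groups, and comparing each result to `reference_scores`, would show which layout the installed Triton version actually expects.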
    scores = tl.dot_scaled(
        q_packed,
        q_scale,
        "e2m1",
        k_packed,
        k_scale,
        "e2m1",
        lhs_k_pack=True,
        rhs_k_pack=True,
        out_dtype=tl.float32,
    )
Similar to the paged kernel, tl.dot_scaled here expects the RHS scale to be (K_scaled, N). Since k_scale is loaded as (BLOCK_KV, 4), it needs to be transposed to match the expected (4, BLOCK_KV) shape.
Suggested change:

      scores = tl.dot_scaled(
          q_packed,
          q_scale,
          "e2m1",
          k_packed,
    -     k_scale,
    +     tl.trans(k_scale),
          "e2m1",
          lhs_k_pack=True,
          rhs_k_pack=True,
          out_dtype=tl.float32,
      )
This pull request has merge conflicts that must be resolved before it can be
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.