[WIP][Kernel] Generalized LL GEMMs with PDL#39897
Conversation
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
There was a problem hiding this comment.
Code Review
This pull request introduces a low-latency (LL) router GEMM implementation using cuteDSL (CUTLASS Python) to optimize MoE gate linear layers on SM90+ hardware for small batch sizes (M <= 16). The changes include new bf16 and fp8 kernels, a JIT-compiled Python wrapper, and integration into the GateLinear dispatch logic as a high-priority tier. Review feedback identifies two critical issues: the fp8 implementation will fail if the hidden dimension K is odd due to invalid int16 casting, and the kernels' use of 32-byte alignment for vectorized loads requires K to be a multiple of 16 to prevent incorrect data access or crashes during expert/token indexing.
| if is_fp8: | ||
| a_flat = hidden_states.view(torch.int16).reshape(-1) | ||
| b_flat = router_weight.view(torch.int16).reshape(-1) |
There was a problem hiding this comment.
The fp8 implementation will crash with a RuntimeError if K is odd because torch.view(torch.int16) requires the last dimension to be a multiple of 2. Additionally, K_eff = K // 2 on line 108 will cause the kernel to ignore the last element of each row, leading to incorrect results. The dispatch logic in gate_linear.py should be updated to ensure K is even for fp8 inputs, or the implementation should be adjusted to handle the remainder.
| bp = (gB.iterator + (n_idx * K_dim + kb)).align(32) | ||
| bt = cute.make_tensor(bp, cute.make_layout((VPT,))) | ||
| br = cute.make_rmem_tensor((VPT,), elem) | ||
| cute.autovec_copy(bt, br) | ||
| for m in cutlass.range_constexpr(M): | ||
| ap = (gA.iterator + (m * K_dim + kb)).align(32) |
There was a problem hiding this comment.
The kernel uses .align(32) for bf16 loads, which requires the hidden dimension K to be a multiple of 16 (32 bytes) so that every expert's row in the weight matrix gB and every token's row in gA is properly aligned. If K is not a multiple of 16, n_idx * K_dim (for n_idx > 0) or m * K_dim (for m > 0) will not be 32-byte aligned, leading to incorrect data being loaded or a crash. Similar alignment assumptions exist in the fp8 kernel. The dispatch logic should enforce K % 16 == 0 or the kernel should be made truly alignment-agnostic.
|
should we be using TVM FFI? see: Dao-AILab/flash-attention#2042 |
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
|
Hi @LopezCastroRoberto, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
| @dsl_user_op | ||
| def fused_fp8_mma_2n( | ||
| c0, | ||
| c1, | ||
| c2, | ||
| c3, | ||
| c4, | ||
| c5, | ||
| c6, | ||
| c7, | ||
| a0_lo, | ||
| a0_hi, | ||
| a1_lo, | ||
| a1_hi, | ||
| a2_lo, | ||
| a2_hi, | ||
| a3_lo, | ||
| a3_hi, | ||
| b0_lo, | ||
| b0_hi, | ||
| b1_lo, | ||
| b1_hi, | ||
| b2_lo, | ||
| b2_hi, | ||
| b3_lo, | ||
| b3_hi, | ||
| *, | ||
| loc=None, | ||
| ip=None, | ||
| ): |
There was a problem hiding this comment.
can we use cute primitives instead?
|
Hi @LopezCastroRoberto, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
|
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
This is still WIP, but feedback is welcomed
Motivation
#38772
Benchmarks
LL BF16 kernels (nvidia/DeepSeek-V3.2-NVFP4)
MAIN:
PR:
TPOT improves ~8%
LL FP8 kernels (mistralai/Mistral-Medium-3.5-128B) -- Per-tensor scaling
MAIN:
PR:
TPOT improves ~5%