[ROCm][Kernel] Add GFX11 (RDNA3/4) support for wvSplitK skinny GEMM kernels #34176
mgehre-amd wants to merge 1 commit into vllm-project:main
Conversation
…nels

Enable the skinny GEMM wvSplitK kernels on all GFX11 (RDNA3) GPUs.

Key adaptations for RDNA's wave32 execution model:
- Use v_dot2acc_f32_f16 (renamed from v_dot2c_f32_f16 on GFX11)
- Wave32 butterfly reduction via __shfl_xor instead of DPP row_shr
- Use THRDS=32 on GFX11 (one wavefront per row) to avoid cross-wavefront reduction overhead; keep THRDS=64 on GFX9
- Add is_gfx11() runtime helper for host-side dispatch
- Add on_gfx11() helper and use on_gfx9() or on_gfx11() to gate the Python dispatch path
- Fix pre-existing test tolerance for large-K fp16 accumulation error

Benchmarked on AMD Strix Halo (gfx1151, LPDDR5X-8000 128 GB):

Model: Qwen/Qwen3-4B (bf16, unquantized)

input-len=128, output-len=128, num-prompts=5:
- Median TTFT: 102.61 ms → 107.81 ms (no change, within noise)
- Median TPOT: 50.03 ms → 40.18 ms (19.7% faster)

input-len=1920, output-len=128, num-prompts=5:
- Median TTFT: 875.22 ms → 872.81 ms (no change)
- Median TPOT: 51.39 ms → 41.37 ms (19.5% faster)

Model: hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4

input-len=128, output-len=128, num-prompts=5:
- Median TTFT: 199 ms → 195 ms (2% faster)
- Median TPOT: 172.22 ms → 169.49 ms (1.6% faster)

input-len=1920, output-len=128, num-prompts=5:
- Median TTFT: 1487 ms → 1480 ms (0.5% faster)
- Median TPOT: 173.33 ms → 170.54 ms (1.6% faster)

Direct kernel benchmarks (THRDS=32 vs THRDS=64 vs torch.nn.functional.linear):

| Layer | N | K | M | THRDS=64 | THRDS=32 | torch | 32 vs torch |
|---|---|---|---|---|---|---|---|
| qkv_proj | 1 | 2560 | 3840 | 93 us | 88 us | 88 us | 0.98x |
| o_proj | 1 | 2560 | 2560 | 58 us | 56 us | 70 us | 1.20x |
| gate_up | 1 | 2560 | 19456 | 456 us | 444 us | 535 us | 1.21x |
| down_proj | 1 | 9728 | 2560 | 235 us | 221 us | 264 us | 1.19x |
| lm_head (Qwen) | 1 | 2560 | 151936 | 3504 us | 3425 us | 3623 us | 1.06x |
| lm_head (Llama) | 1 | 4096 | 128256 | 4777 us | 5056 us | 7614 us | 1.51x |

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Code Review
This pull request successfully adds support for GFX11 (RDNA3/4) GPUs to the wvSplitK skinny GEMM kernels, which should provide a significant performance boost for decode-phase GEMMs on these architectures. The implementation correctly handles the instruction renaming and introduces a new wave32 reduction path. The Python dispatch logic and test tolerances are also updated accordingly. My main feedback is regarding significant code duplication in the CUDA/HIP kernels. I've added a high-severity comment suggesting a refactoring to improve long-term maintainability.
csrc/rocm/skinny_gemms.cu (511-545)
This block of code for GFX11 reduction is substantially duplicated across three kernels: wvSplitK_hf_sml_, wvSplitK_hf_, and wvSplitK_hf_big_. This duplication negatively impacts maintainability, as any future changes to this logic will need to be manually applied in three different places, increasing the risk of introducing inconsistencies or bugs.
To address this, I recommend refactoring the GFX11-specific reduction and output logic into a shared __device__ function. This function could be parameterized to handle the minor differences between kernels, such as the conditional commit based on commitColumn.
Additionally, the GFX11-specific shared memory allocation logic is also duplicated and could be similarly refactored into a helper struct or macro to improve code reuse and clarity.
Summary
Enable the wvSplitK skinny GEMM kernels on all GFX11 (RDNA3) GPUs. These kernels accelerate decode-phase GEMMs (small batch size, large K) by using wave-level split-K with butterfly reduction.
Key adaptations for RDNA's wave32 execution model:
- Use `v_dot2acc_f32_f16` (renamed from `v_dot2c_f32_f16` on GFX11)
- Wave32 butterfly reduction via `__shfl_xor` instead of GFX9's DPP `row_shr`
- `THRDS=32` on GFX11 (one wavefront per row); keep `THRDS=64` on GFX9
- `is_gfx11()` runtime helper for host-side kernel dispatch
- `on_gfx11()` helper; use `on_gfx9() or on_gfx11()` to gate the Python dispatch path
- Fix pre-existing test tolerance for large-K fp16 accumulation error (tolerance scales with `sqrt(K)`)

Performance
Benchmarked on AMD Strix Halo (gfx1151, LPDDR5X-8000 128 GB):
Qwen/Qwen3-4B (bf16, unquantized) — decode TPOT ~20% faster (detailed numbers in the commit message above).
hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 — minimal impact (~1.6% TPOT improvement), as the AWQ layers bypass the skinny GEMM path.
TTFT (prefill) is unaffected, as the skinny kernels only target GEMMs with batch_size ≤ 4.
Test plan
- `test_rocm_wvsplitk_kernel` — all 22 parametrized cases pass on gfx1151
- `test_rocm_wvsplitk_bias1D_kernel` — all 22 parametrized cases pass
- `test_rocm_wvsplitk_bias2D_kernel` — all 22 parametrized cases pass