[ROCm][Kernel] Add GFX11 (RDNA3) support for wvSplitK skinny GEMM kernels #34177
Closed
mgehre-amd wants to merge 1 commit into vllm-project:main from
Conversation
Contributor
Code Review
This pull request adds GFX11 (RDNA3) GPU support to the wvSplitK skinny GEMM kernels, yielding significant performance improvements for decode-phase GEMMs. The changes adapt the kernels to wave32 execution, use new instructions, and update dispatch logic on both the host and device sides. The implementation is mostly correct, but I've identified a few areas for improvement regarding thread safety and code duplication in the CUDA kernels, which would improve robustness and maintainability.
…nels

Enable the skinny GEMM wvSplitK kernels on all GFX11 (RDNA3) GPUs.

Key adaptations for RDNA's wave32 execution model:
- Use v_dot2acc_f32_f16 (renamed from v_dot2c_f32_f16 on GFX11)
- Wave32 butterfly reduction via __shfl_xor instead of DPP row_shr
- Use THRDS=32 on GFX11 (one wavefront per row) to avoid cross-wavefront reduction overhead; keep THRDS=64 on GFX9
- Add is_gfx11() runtime helper for host-side dispatch
- Add on_gfx11() helper and use on_gfx9() or on_gfx11() to gate the Python dispatch path
- Fix pre-existing test tolerance for large-K fp16 accumulation error

Benchmarked on AMD Strix Halo (gfx1151, LPDDR5X-8000 128 GB):

Model: Qwen/Qwen3-4B (bf16, unquantized)
- input-len=128, output-len=128, num-prompts=5:
  Median TTFT: 102.61 ms → 107.81 ms (no change, within noise)
  Median TPOT: 50.03 ms → 40.18 ms (19.7% faster)
- input-len=1920, output-len=128, num-prompts=5:
  Median TTFT: 875.22 ms → 872.81 ms (no change)
  Median TPOT: 51.39 ms → 41.37 ms (19.5% faster)

Model: hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4
- input-len=128, output-len=128, num-prompts=5:
  Median TTFT: 199 ms → 195 ms (2% faster)
  Median TPOT: 172.22 ms → 169.49 ms (1.6% faster)
- input-len=1920, output-len=128, num-prompts=5:
  Median TTFT: 1487 ms → 1480 ms (0.5% faster)
  Median TPOT: 173.33 ms → 170.54 ms (1.6% faster)

Direct kernel benchmarks (THRDS=32 vs THRDS=64 vs torch.nn.functional.linear):

Layer           N  K     M       THRDS=64  THRDS=32  torch    32 vs torch
qkv_proj        1  2560  3840    93 us     88 us     88 us    0.98x
o_proj          1  2560  2560    58 us     56 us     70 us    1.20x
gate_up         1  2560  19456   456 us    444 us    535 us   1.21x
down_proj       1  9728  2560    235 us    221 us    264 us   1.19x
lm_head(Qwen)   1  2560  151936  3504 us   3425 us   3623 us  1.06x
lm_head(Llama)  1  4096  128256  4777 us   5056 us   7614 us  1.51x

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
793c20c to ff5a327
This pull request has merge conflicts that must be resolved before it can be merged.
Contributor
Author
Closed in favor of #34709
laudney pushed a commit to mmonad/vllm that referenced this pull request on Feb 20, 2026
Narrow wvSplitKQ kernel compile guards from __HIP__GFX1X__ to __HIP__GFX12__ since gfx11 (RDNA3) has no FP8 hardware — the previous guard let gfx11 fall through to the MFMA #else path it can't execute. Apply eps * sqrt(k) absolute tolerance to wvSplitK tests to properly model fp16 accumulation error scaling (credit: @mgehre-amd vllm-project#34177). Signed-off-by: L.B.R. <lbr@mmonad.com>
Summary
Enable the wvSplitK skinny GEMM kernels on all GFX11 (RDNA3) GPUs. These kernels accelerate decode-phase GEMMs (small batch size, large K) by using wave-level split-K with butterfly reduction.
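The wave-level butterfly reduction mentioned above can be illustrated with a small Python simulation. This is only a sketch of the idea, not the kernel code: the real kernel performs these steps in hardware with `__shfl_xor` across a 32-lane RDNA3 wavefront.

```python
# Simulate a wave32 butterfly reduction: each "lane" holds a partial
# dot-product sum, and after log2(32) = 5 XOR-shuffle steps every lane
# holds the full sum. This mirrors what __shfl_xor does in hardware.

WAVE_SIZE = 32  # RDNA3 native wavefront width (GFX9 uses wave64)

def butterfly_reduce(lanes):
    assert len(lanes) == WAVE_SIZE
    vals = list(lanes)
    offset = WAVE_SIZE // 2
    while offset >= 1:
        # Each lane adds the value held by the lane whose index differs
        # from its own by `offset` (the XOR-shuffle partner).
        vals = [vals[i] + vals[i ^ offset] for i in range(WAVE_SIZE)]
        offset //= 2
    return vals

partials = [float(i) for i in range(WAVE_SIZE)]
print(butterfly_reduce(partials)[0])  # sum(0..31) = 496.0 in every lane
```

Because every lane ends with the full sum, no extra broadcast step is needed after the reduction; on GFX9's wave64 the same pattern takes six steps (or a DPP `row_shr` sequence, as the kernel used previously).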
Key adaptations for RDNA's wave32 execution model:

- Use v_dot2acc_f32_f16 (renamed from v_dot2c_f32_f16 on GFX11)
- Wave32 butterfly reduction via __shfl_xor instead of GFX9's DPP row_shr
- Use THRDS=32 on GFX11 (one wavefront per row); keep THRDS=64 on GFX9
- Add is_gfx11() runtime helper for host-side kernel dispatch
- Add on_gfx11() helper and use on_gfx9() or on_gfx11() to gate the Python dispatch path
- Fix test tolerance to scale with sqrt(K) for fp16 accumulation error

Performance
Benchmarked on AMD Strix Halo (gfx1151, LPDDR5X-8000 128 GB):
Qwen/Qwen3-4B (bf16, unquantized) — decode TPOT ~20% faster:
hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 — minimal impact (AWQ layers bypass skinny GEMM path):
TTFT (prefill) is unaffected as the skinny kernels only target batch_size ≤ 4 GEMMs.
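The gating described here might look roughly like the following. The helper names `on_gfx9()`/`on_gfx11()` come from this PR; the architecture probing, the `use_skinny_gemm` wrapper, and the listed gfx9 architecture strings are illustrative stand-ins, not the actual vLLM code.

```python
# Sketch of the Python-side dispatch gating described in this PR.
# on_gfx9()/on_gfx11() mirror the helpers named above; everything else
# (architecture query, batch-size cutoff wiring) is illustrative.

def _gcn_arch():
    # Stand-in for querying the GPU architecture string, e.g. via
    # torch.cuda.get_device_properties(0).gcnArchName on a ROCm build.
    return "gfx1151"

def on_gfx9():
    # Example gfx9 data-center architectures; the real list may differ.
    return any(_gcn_arch().startswith(a) for a in ("gfx90a", "gfx942"))

def on_gfx11():
    return _gcn_arch().startswith("gfx11")

def use_skinny_gemm(batch_size):
    # The skinny wvSplitK path only targets decode-shaped GEMMs
    # (batch_size <= 4), so prefill GEMMs fall through to torch.
    return batch_size <= 4 and (on_gfx9() or on_gfx11())

print(use_skinny_gemm(1))   # True: decode-shaped, supported arch
print(use_skinny_gemm(16))  # False: batch too large for the skinny path
```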
Test plan
- test_rocm_wvsplitk_kernel — all 22 parametrized cases pass on gfx1151
- test_rocm_wvsplitk_bias1D_kernel — all 22 parametrized cases pass
- test_rocm_wvsplitk_bias2D_kernel — all 22 parametrized cases pass
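The sqrt(K)-scaled tolerance fix referenced above can be sketched like this. The function name and exact constant are assumptions for illustration; only the eps * sqrt(K) shape is taken from the PR.

```python
import math

# A K-length fp16 dot product of zero-mean values accumulates rounding
# error that grows roughly with sqrt(K), so an absolute test tolerance
# should scale the same way rather than stay constant.

FP16_EPS = 2 ** -10  # machine epsilon of IEEE binary16 (10 mantissa bits)

def wvsplitk_atol(k):
    # eps * sqrt(k) absolute tolerance, as described in this PR
    # (illustrative helper name, not the actual test-suite code).
    return FP16_EPS * math.sqrt(k)

# Quadrupling K doubles the allowed absolute error:
print(wvsplitk_atol(2560) / wvsplitk_atol(640))  # 2.0
```

With a fixed tolerance, the large-K cases (e.g. K = 9728 for down_proj) fail spuriously even when the kernel is correct, which is the pre-existing issue the commit message describes.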