[ROCm][Kernel] Add GFX11 (RDNA3) support for wvSplitK skinny GEMM kernels #34177

Closed
mgehre-amd wants to merge 1 commit into vllm-project:main from mgehre-amd:matthias.skinny_gemm_gfx11
Conversation

Contributor

@mgehre-amd mgehre-amd commented Feb 9, 2026

Summary

Enable the wvSplitK skinny GEMM kernels on all GFX11 (RDNA3) GPUs. These kernels accelerate decode-phase GEMMs (small batch size, large K) by using wave-level split-K with butterfly reduction.

Key adaptations for RDNA's wave32 execution model:

  • Use v_dot2acc_f32_f16 (renamed from v_dot2c_f32_f16 on GFX11)
  • Wave32 butterfly reduction via __shfl_xor instead of GFX9's DPP row_shr
  • Use THRDS=32 on GFX11 (one wavefront per row); keep THRDS=64 on GFX9
  • Add is_gfx11() runtime helper for host-side kernel dispatch
  • Add on_gfx11() helper and use on_gfx9() or on_gfx11() to gate the Python dispatch path
  • Relax wvSplitK test tolerance to account for fp16 accumulation error scaling with sqrt(K)
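
The wave32 butterfly reduction named above can be illustrated with a small pure-Python simulation. This is a hedged sketch, not the actual HIP kernel: `__shfl_xor(v, offset)` lets lane `i` read the value held by lane `i ^ offset`, and after log2(32) = 5 exchange steps every lane holds the full sum.

```python
WAVE_SIZE = 32  # RDNA3 wave32; GFX9 uses wave64 with DPP row_shr instead


def butterfly_reduce(lanes):
    """Reduce per-lane partial sums via XOR-pair exchanges.

    Models the __shfl_xor butterfly: at each step, lane i adds the
    value of lane i ^ offset, halving the offset until it reaches 1.
    """
    assert len(lanes) == WAVE_SIZE
    vals = list(lanes)
    offset = WAVE_SIZE // 2
    while offset >= 1:
        # __shfl_xor(v, offset): lane i reads lane i ^ offset
        vals = [vals[i] + vals[i ^ offset] for i in range(WAVE_SIZE)]
        offset //= 2
    return vals


partials = list(range(WAVE_SIZE))       # stand-in per-lane partial sums
reduced = butterfly_reduce(partials)
# every lane ends up with the full sum (496 for 0..31)
assert all(v == sum(range(WAVE_SIZE)) for v in reduced)
```

Because every lane ends with the total, no cross-wavefront step is needed when THRDS matches the wave size, which is the motivation for THRDS=32 on GFX11.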

Performance

Benchmarked on AMD Strix Halo (gfx1151, LPDDR5X-8000 128 GB):

Qwen/Qwen3-4B (bf16, unquantized) — decode TPOT ~20% faster:

Workload                Metric       Baseline   Optimized  Change
input=128,  output=128  Median TPOT  50.03 ms   40.18 ms   19.7% faster
input=1920, output=128  Median TPOT  51.39 ms   41.37 ms   19.5% faster

hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 — minimal impact (AWQ layers bypass skinny GEMM path):

Workload                Metric       Baseline   Optimized  Change
input=128,  output=128  Median TPOT  172.22 ms  169.49 ms  1.6% faster
input=1920, output=128  Median TPOT  173.33 ms  170.54 ms  1.6% faster

TTFT (prefill) is unaffected as the skinny kernels only target batch_size ≤ 4 GEMMs.
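
The gating logic can be sketched as follows. The `on_gfx9()`/`on_gfx11()` helper names come from this PR, but `use_skinny_gemm()` and the stub bodies are illustrative, not vLLM's actual dispatch code (the real helpers query the ROCm device architecture):

```python
def on_gfx9() -> bool:
    return False  # stub; the real helper inspects the GPU arch string


def on_gfx11() -> bool:
    return True   # stub


def use_skinny_gemm(batch_size: int) -> bool:
    # The skinny kernels only target small decode GEMMs, so prefill
    # (large batch) GEMMs and therefore TTFT are unaffected.
    return batch_size <= 4 and (on_gfx9() or on_gfx11())


assert use_skinny_gemm(1)       # decode step: eligible
assert not use_skinny_gemm(64)  # prefill: falls back to the default path
```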

Test plan

  • test_rocm_wvsplitk_kernel — all 22 parametrized cases pass on gfx1151
  • test_rocm_wvsplitk_bias1D_kernel — all 22 parametrized cases pass
  • test_rocm_wvsplitk_bias2D_kernel — all 22 parametrized cases pass
  • Sanity check: model produces correct answers to math questions in both benchmarks
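
The relaxed tolerance can be expressed directly: fp16 round-off over a length-K dot product accumulates roughly like a random walk, so the absolute error grows on the order of eps * sqrt(K). The helper name `wvsplitk_atol` and the example K are illustrative, not the test's actual code:

```python
import math

FP16_EPS = 2.0 ** -10  # machine epsilon of IEEE binary16


def wvsplitk_atol(k: int) -> float:
    """Absolute tolerance for a length-k fp16 dot product."""
    # Round-off errors partially cancel, so the expected magnitude
    # scales with sqrt(k) rather than k.
    return FP16_EPS * math.sqrt(k)


print(wvsplitk_atol(4096))  # -> 0.0625
```

A fixed tolerance that passes at small K would spuriously fail for large-K layers (e.g. an lm_head projection), which is why the test scales it with K.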

@mergify mergify bot added the rocm Related to AMD ROCm label Feb 9, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 9, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for GFX11 (RDNA3) GPUs to the wvSplitK skinny GEMM kernels, which shows significant performance improvements for decode-phase GEMMs. The changes include adapting to wave32 execution, using new instructions, and updating dispatch logic on both host and device sides. The implementation is mostly correct, but I've identified a few areas for improvement regarding thread-safety and code duplication in the CUDA kernels, which would enhance robustness and maintainability.

…nels

Enable the skinny GEMM wvSplitK kernels on all GFX11 (RDNA3) GPUs.
Key adaptations for RDNA's wave32 execution model:

- Use v_dot2acc_f32_f16 (renamed from v_dot2c_f32_f16 on GFX11)
- Wave32 butterfly reduction via __shfl_xor instead of DPP row_shr
- Use THRDS=32 on GFX11 (one wavefront per row) to avoid
  cross-wavefront reduction overhead; keep THRDS=64 on GFX9
- Add is_gfx11() runtime helper for host-side dispatch
- Add on_gfx11() helper and use on_gfx9() or on_gfx11() to gate
  the Python dispatch path
- Fix pre-existing test tolerance for large-K fp16 accumulation error

Benchmarked on AMD Strix Halo (gfx1151, LPDDR5X-8000 128 GB):

Model: Qwen/Qwen3-4B (bf16, unquantized)
input-len=128, output-len=128, num-prompts=5:
  Median TTFT: 102.61 ms → 107.81 ms (no change, within noise)
  Median TPOT: 50.03 ms → 40.18 ms (19.7% faster)

input-len=1920, output-len=128, num-prompts=5:
  Median TTFT: 875.22 ms → 872.81 ms (no change)
  Median TPOT: 51.39 ms → 41.37 ms (19.5% faster)

Model: hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4
input-len=128, output-len=128, num-prompts=5:
  Median TTFT: 199 ms → 195 ms (2% faster)
  Median TPOT: 172.22 ms → 169.49 ms (1.6% faster)

input-len=1920, output-len=128, num-prompts=5:
  Median TTFT: 1487 ms → 1480 ms (0.5% faster)
  Median TPOT: 173.33 ms → 170.54 ms (1.6% faster)

Direct kernel benchmarks (THRDS=32 vs THRDS=64 vs torch.nn.functional.linear):

Layer          N     K       M  THRDS=64  THRDS=32  torch   32 vs torch
qkv_proj       1  2560    3840     93 us     88 us   88 us     0.98x
o_proj         1  2560    2560     58 us     56 us   70 us     1.20x
gate_up        1  2560   19456    456 us    444 us  535 us     1.21x
down_proj      1  9728    2560    235 us    221 us  264 us     1.19x
lm_head(Qwen)  1  2560  151936   3504 us   3425 us 3623 us    1.06x
lm_head(Llama) 1  4096  128256   4777 us   5056 us 7614 us    1.51x

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@mgehre-amd mgehre-amd force-pushed the matthias.skinny_gemm_gfx11 branch from 793c20c to ff5a327 Compare February 9, 2026 22:31
@mergify

mergify bot commented Feb 10, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mgehre-amd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mgehre-amd
Contributor Author

Closed in favor of #34709

laudney pushed a commit to mmonad/vllm that referenced this pull request Feb 20, 2026
Narrow wvSplitKQ kernel compile guards from __HIP__GFX1X__ to
__HIP__GFX12__ since gfx11 (RDNA3) has no FP8 hardware — the previous
guard let gfx11 fall through to the MFMA #else path it can't execute.

Apply eps * sqrt(k) absolute tolerance to wvSplitK tests to properly
model fp16 accumulation error scaling (credit: @mgehre-amd vllm-project#34177).

Signed-off-by: L.B.R. <lbr@mmonad.com>
@mgehre-amd mgehre-amd closed this Mar 20, 2026
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 20, 2026

Labels

needs-rebase, rocm (Related to AMD ROCm)

Projects

Status: Done
