
[ROCm][Kernel] Add GFX11 (RDNA3/4) support for wvSplitK skinny GEMM kernels#34176

Closed
mgehre-amd wants to merge 1 commit into vllm-project:main from mgehre-amd:matthias.skinny_gemm_gfx11

Conversation

mgehre-amd (Contributor) commented Feb 9, 2026

Summary

Enable the wvSplitK skinny GEMM kernels on all GFX11 (RDNA3) GPUs. These kernels accelerate decode-phase GEMMs (small batch size, large K) by using wave-level split-K with butterfly reduction.
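Split-K here means partitioning the K (reduction) dimension of the GEMM across waves, with each wave producing a partial sum that is then combined. A toy Python illustration of the idea for a single output element (purely conceptual; the kernel does this with hardware waves and registers, not lists):

```python
# Conceptual split-K: partition the reduction dimension K across
# "waves", compute partial dot products, then combine them.
# Pure-Python illustration of the idea behind wvSplitK, not the kernel itself.

def split_k_dot(a, b, num_splits):
    k = len(a)
    assert k % num_splits == 0
    chunk = k // num_splits
    # Each "wave" reduces its own K-slice independently...
    partials = [
        sum(a[i] * b[i] for i in range(s * chunk, (s + 1) * chunk))
        for s in range(num_splits)
    ]
    # ...then the partial sums are combined (in the kernel, via the
    # butterfly reduction across lanes).
    return sum(partials)

a = [float(i) for i in range(8)]
b = [1.0] * 8
print(split_k_dot(a, b, 4))   # 28.0, same as an unsplit dot product
```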

Key adaptations for RDNA's wave32 execution model:

  • Use v_dot2acc_f32_f16 (renamed from v_dot2c_f32_f16 on GFX11)
  • Wave32 butterfly reduction via __shfl_xor instead of GFX9's DPP row_shr
  • Use THRDS=32 on GFX11 (one wavefront per row); keep THRDS=64 on GFX9
  • Add is_gfx11() runtime helper for host-side kernel dispatch
  • Add on_gfx11() helper and use on_gfx9() or on_gfx11() to gate the Python dispatch path
  • Relax wvSplitK test tolerance to account for fp16 accumulation error scaling with sqrt(K)
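The wave32 butterfly reduction replaces GFX9's DPP `row_shr` shuffles with XOR-lane exchanges. A minimal Python sketch of the XOR-butterfly pattern over a simulated 32-lane wave (a list stands in for the wave; the real kernel exchanges values with `__shfl_xor`):

```python
# Simulate a wave32 butterfly (XOR-lane) reduction in plain Python.
# In the HIP kernel each "lane" is a hardware thread and the exchange
# is done with __shfl_xor; here a list stands in for the wave.

def butterfly_reduce(lanes):
    """Sum all lane values; after log2(n) XOR steps every lane holds the total."""
    n = len(lanes)
    assert n & (n - 1) == 0, "wave size must be a power of two"
    vals = list(lanes)
    offset = 1
    while offset < n:
        # Each lane adds the value held by its XOR partner at this distance.
        vals = [vals[i] + vals[i ^ offset] for i in range(n)]
        offset *= 2
    return vals

wave = list(range(32))          # lane i holds partial dot product i
result = butterfly_reduce(wave)
print(result[0], result[31])    # every lane ends with the same sum
```

After log2(32) = 5 steps, all 32 lanes hold the full sum, which is why no cross-wavefront reduction is needed once THRDS=32 maps one wavefront to one row.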

Performance

Benchmarked on AMD Strix Halo (gfx1151, LPDDR5X-8000 128 GB):

Qwen/Qwen3-4B (bf16, unquantized) — decode TPOT ~20% faster:

| Workload | Metric | Baseline | Optimized | Change |
| --- | --- | --- | --- | --- |
| input=128, output=128 | Median TPOT | 50.03 ms | 40.18 ms | 19.7% faster |
| input=1920, output=128 | Median TPOT | 51.39 ms | 41.37 ms | 19.5% faster |

hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 — minimal impact (AWQ layers bypass skinny GEMM path):

| Workload | Metric | Baseline | Optimized | Change |
| --- | --- | --- | --- | --- |
| input=128, output=128 | Median TPOT | 172.22 ms | 169.49 ms | 1.6% faster |
| input=1920, output=128 | Median TPOT | 173.33 ms | 170.54 ms | 1.6% faster |

TTFT (prefill) is unaffected as the skinny kernels only target batch_size ≤ 4 GEMMs.
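The host-side gating just described can be sketched as follows. The helper names mirror the PR's description (`on_gfx9()` / `on_gfx11()`), but the surrounding plumbing is illustrative, not vLLM's actual code:

```python
# Illustrative sketch of the skinny-GEMM dispatch gate. Helper names follow
# the PR description (on_gfx9 / on_gfx11); everything else is hypothetical.

def on_gfx9(arch: str) -> bool:
    return arch.startswith("gfx9")

def on_gfx11(arch: str) -> bool:
    return arch.startswith("gfx11")

def use_wvsplitk(arch: str, batch_size: int) -> bool:
    # Skinny kernels only target decode-phase GEMMs (batch_size <= 4),
    # which is why prefill/TTFT is unaffected.
    return (on_gfx9(arch) or on_gfx11(arch)) and batch_size <= 4

print(use_wvsplitk("gfx1151", 1))   # True: GFX11 decode step
print(use_wvsplitk("gfx1151", 64))  # False: prefill-sized batch
print(use_wvsplitk("gfx1030", 1))   # False: RDNA2 not covered by this PR
```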

Test plan

  • test_rocm_wvsplitk_kernel — all 22 parametrized cases pass on gfx1151
  • test_rocm_wvsplitk_bias1D_kernel — all 22 parametrized cases pass
  • test_rocm_wvsplitk_bias2D_kernel — all 22 parametrized cases pass
  • Sanity check: model produces correct answers to math questions in both benchmarks
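The relaxed tolerance mentioned above follows the standard error model for fp16 accumulation: rounding error in a length-K dot product grows roughly with sqrt(K). A sketch of a K-scaled tolerance (the constants here are illustrative, not the values used in the vLLM tests):

```python
import math

def wvsplitk_atol(k: int, base: float = 1e-2) -> float:
    """Absolute tolerance that grows with sqrt(K), the expected scaling
    of accumulated fp16 rounding error in a length-K dot product.
    The base constant and normalization are illustrative only."""
    return base * math.sqrt(k / 1024)

for k in (1024, 4096, 16384):
    print(f"K={k:6d}  atol={wvsplitk_atol(k):.4f}")
```

Scaling the tolerance this way keeps small-K cases strict while avoiding spurious failures for large-K layers such as `lm_head`.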

[ROCm][Kernel] Add GFX11 (RDNA3/4) support for wvSplitK skinny GEMM kernels

Enable the skinny GEMM wvSplitK kernels on all GFX11 (RDNA3) GPUs.
Key adaptations for RDNA's wave32 execution model:

- Use v_dot2acc_f32_f16 (renamed from v_dot2c_f32_f16 on GFX11)
- Wave32 butterfly reduction via __shfl_xor instead of DPP row_shr
- Use THRDS=32 on GFX11 (one wavefront per row) to avoid
  cross-wavefront reduction overhead; keep THRDS=64 on GFX9
- Add is_gfx11() runtime helper for host-side dispatch
- Add on_gfx11() helper and use on_gfx9() or on_gfx11() to gate
  the Python dispatch path
- Fix pre-existing test tolerance for large-K fp16 accumulation error

Benchmarked on AMD Strix Halo (gfx1151, LPDDR5X-8000 128 GB):

Model: Qwen/Qwen3-4B (bf16, unquantized)
input-len=128, output-len=128, num-prompts=5:
  Median TTFT: 102.61 ms → 107.81 ms (no change, within noise)
  Median TPOT: 50.03 ms → 40.18 ms (19.7% faster)

input-len=1920, output-len=128, num-prompts=5:
  Median TTFT: 875.22 ms → 872.81 ms (no change)
  Median TPOT: 51.39 ms → 41.37 ms (19.5% faster)

Model: hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4
input-len=128, output-len=128, num-prompts=5:
  Median TTFT: 199 ms → 195 ms (2% faster)
  Median TPOT: 172.22 ms → 169.49 ms (1.6% faster)

input-len=1920, output-len=128, num-prompts=5:
  Median TTFT: 1487 ms → 1480 ms (0.5% faster)
  Median TPOT: 173.33 ms → 170.54 ms (1.6% faster)

Direct kernel benchmarks (THRDS=32 vs THRDS=64 vs torch.nn.functional.linear):

Layer          N     K       M  THRDS=64  THRDS=32  torch   32 vs torch
qkv_proj       1  2560    3840     93 us     88 us   88 us     0.98x
o_proj         1  2560    2560     58 us     56 us   70 us     1.20x
gate_up        1  2560   19456    456 us    444 us  535 us     1.21x
down_proj      1  9728    2560    235 us    221 us  264 us     1.19x
lm_head(Qwen)  1  2560  151936   3504 us   3425 us 3623 us    1.06x
lm_head(Llama) 1  4096  128256   4777 us   5056 us 7614 us    1.51x

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
mgehre-amd force-pushed the matthias.skinny_gemm_gfx11 branch from 8ae957e to 793c20c on February 9, 2026 22:22
@mergify mergify bot added the rocm Related to AMD ROCm label Feb 9, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 9, 2026
gemini-code-assist bot left a comment

Code Review

This pull request successfully adds support for GFX11 (RDNA3/4) GPUs to the wvSplitK skinny GEMM kernels, which should provide a significant performance boost for decode-phase GEMMs on these architectures. The implementation correctly handles the instruction renaming and introduces a new wave32 reduction path. The Python dispatch logic and test tolerances are also updated accordingly. My main feedback is regarding significant code duplication in the CUDA/HIP kernels. I've added a high-severity comment suggesting a refactoring to improve long-term maintainability.

I am having trouble creating individual review comments, so my feedback is inlined below.

csrc/rocm/skinny_gemms.cu (lines 511-545), severity: high

This block of code for GFX11 reduction is substantially duplicated across three kernels: wvSplitK_hf_sml_, wvSplitK_hf_, and wvSplitK_hf_big_. This duplication negatively impacts maintainability, as any future changes to this logic will need to be manually applied in three different places, increasing the risk of introducing inconsistencies or bugs.

To address this, I recommend refactoring the GFX11-specific reduction and output logic into a shared __device__ function. This function could be parameterized to handle the minor differences between kernels, such as the conditional commit based on commitColumn.

Additionally, the GFX11-specific shared memory allocation logic is also duplicated and could be similarly refactored into a helper struct or macro to improve code reuse and clarity.

@mgehre-amd mgehre-amd closed this Feb 9, 2026
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Feb 9, 2026
mergify bot commented Feb 9, 2026

⚠️ The sha of the head commit of this PR conflicts with #34177. Mergify cannot evaluate rules on this PR. ⚠️


Labels

rocm Related to AMD ROCm

Projects

Status: Done
