
[ROCm] Enable wvSplitK skinny GEMM kernel for RDNA4/gfx1x decode#34709

Merged
gshtras merged 9 commits into vllm-project:main from mmonad:feat/rocm-rdna4-wvsplitk on Mar 20, 2026

Conversation

@laudney (Contributor) commented Feb 17, 2026

Summary

Enable the wvSplitK and wvSplitKQ skinny GEMM kernels on RDNA (gfx11/gfx12) hardware for decode-phase GEMMs (M=1..4). Previously these kernels were gfx9-only (MI-series), with RDNA falling through to torch.nn.functional.linear (unquantized) or generic torch._scaled_mm (FP8).

~15% decode token/s improvement on AMD Radeon AI PRO R9700 (gfx1201).

Commit 1: wvSplitK — unquantized BF16/FP16 skinny GEMM

Kernel changes (csrc/rocm/skinny_gemms.cu)

  • __HIP__GFX1X__ compile-time macro for gfx1100/1101/1150/1151/1200/1201
  • on_gfx1x() runtime detection for wave32 vs wave64 dispatch
  • Parameterized launch macros (WVSPLITK_CFG/WVSPLIT_TILE_CFG) accept THRDS and WvPrGrp to support both wave64 (gfx9) and wave32 (RDNA)
  • Wave32 DPP reduction: Replace gfx9-only cross-row instructions (wave_shr:1, row_bcast:15, row_bcast:31) with RDNA wave32 alternative (row_shr:1 + __shfl_xor(val, 16))
  • gfx1x DOT2C: v_dot2_f32_f16 inline asm (VOP3-P encoding, dot10-insts) instead of gfx9's v_dot2c_f32_f16 (VOP2)
  • MFMA guards: #ifdef __HIP__GFX9__ around MFMA bf16 paths (prevents Clang from validating gfx9-only inline asm in dead if constexpr branches)
  • device_guard moved before tensor allocations in wvSplitKrc
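
The gfx1x DOT2C above computes two fp16 products and folds them into an fp32 accumulator in a single instruction. A rough Python emulation of those semantics (f16 is an illustrative rounding helper; Python floats stand in for the fp32 accumulator, which is never rounded back to fp16 between steps):

```python
import struct

def f16(x: float) -> float:
    # Round x to the nearest IEEE binary16 value ('e' = half precision).
    return struct.unpack('<e', struct.pack('<e', x))[0]

def dot2_f32_f16(acc: float, a, b) -> float:
    # dot2 semantics: two fp16 products summed into a wider accumulator
    # in one step; only the inputs see fp16 rounding, not the running sum.
    return acc + f16(a[0]) * f16(b[0]) + f16(a[1]) * f16(b[1])
```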

Python dispatch changes (vllm/model_executor/layers/utils.py)

  • use_skinny condition now includes on_gfx1x() alongside on_gfx9()
  • Removes the prior LLMM1-only gfx1x workaround (use_skinny_llmm1_gfx1x)
  • gfx1x now shares the same wvSplitKLLMM1 dispatch path as gfx9
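
As a sketch of the unified dispatch shape (not the actual utils.py code — helper flags stand in for on_gfx9()/on_gfx1x(), and the thresholds follow the M=1..4 decode window from the summary):

```python
def pick_gemm_path(m: int, on_gfx9: bool, on_gfx1x: bool,
                   rocm_skinny_enabled: bool = True) -> str:
    # use_skinny now covers gfx1x (wave32) alongside gfx9 (wave64)
    use_skinny = rocm_skinny_enabled and (on_gfx9 or on_gfx1x)
    if not use_skinny or m > 4:
        return "torch.nn.functional.linear"   # generic fallback
    if m == 1:
        return "LLMM1"       # shared path, gfx9 and gfx1x alike
    return "wvSplitK"        # skinny GEMM for m in 2..4
```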

Commit 2: wvSplitKQ — FP8 per-tensor skinny GEMM

Kernel changes (csrc/rocm/skinny_gemms.cu)

  • __HIP__GFX12__ compile-time macro for gfx1200/gfx1201
  • on_gfx12() runtime detection for FP8-capable RDNA dispatch
  • Both wvSplitKQ_hf_sml_ and wvSplitKQ_hf_ kernels gain a gfx1x path:
    • Scalar float accumulator (vs floatx16 MFMA vector on gfx9)
    • __builtin_amdgcn_dot4_f32_fp8_fp8() for 4-way FP8 dot products — hardware interprets FP8 bits per target arch (fnuz on gfx9, fn on gfx12)
    • DPP wave32 reduction: row_shr:8/4/2/1 + __shfl_xor(val, 16)
    • Writer lane: threadIdx.x == (THRDS - 1) works for both wave sizes
  • Host dispatch: WVSPLITKQ macro gains _THRDS parameter, runtime on_gfx1x() selects THRDS=32/WvPrGrp=16 for wave32
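
The wave32 reduction sequence above can be modeled lane-by-lane in plain Python (illustrative only — real lanes execute in lockstep, and DPP masking details are simplified):

```python
def wave32_reduce(vals):
    # row_shr:8/4/2/1 folds within each 16-lane row: lane i adds lane
    # i - n when i - n stays inside the row; after the four folds the
    # last lane of each row (15 and 31) holds its row's sum.
    assert len(vals) == 32
    acc = list(vals)
    for n in (8, 4, 2, 1):
        nxt = acc[:]
        for i in range(32):
            if i % 16 >= n:
                nxt[i] = acc[i] + acc[i - n]
        acc = nxt
    # __shfl_xor(val, 16): exchange between the two 16-lane rows, then
    # add, so lanes 15 and 31 both hold the full 32-lane sum.
    acc = [acc[i] + acc[i ^ 16] for i in range(32)]
    return acc[31]   # writer lane threadIdx.x == THRDS - 1
```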

Related PRs (RDNA4/gfx12 series)

Test plan

  • Qwen3-Coder-30B-A3B AWQ-4bit inference on gfx1201 — correct output, ~15% decode speedup
  • Qwen3-14B-FP8 inference on gfx1201 — FP8 path functional (note: this block-wise FP8 model uses the Triton kernel; per-tensor FP8 would use wvSplitKQ)
  • No regression on gfx9/MI-series (kernel guards preserve original wave64 paths)
  • Unit tests: test_rocm_wvsplitk_kernel, test_rocm_unquantized_gemm dispatch tests

@gemini-code-assist (bot) left a comment


Code Review

This pull request is a significant contribution that enables the wvSplitK skinny GEMM kernel for RDNA4/gfx1x architectures on ROCm, which should provide a notable performance improvement for decode-phase GEMMs. The changes are comprehensive, spanning from the CUDA kernel implementation to the Python dispatch logic, dependency management, and testing. The kernel modifications are well-structured, using preprocessor directives to handle different GPU architectures and parameterizing macros for wave32/64 support. The Python code updates correctly integrate the new kernel path and include a bug fix in the dispatch logic. The addition of new tests for the added functionality is also a great touch. I have one suggestion to improve the thread safety of a newly added utility function.

Comment on lines +57 to +69
bool on_gfx1x() {
  static bool is_cached = false;
  static bool result = false;
  if (!is_cached) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    std::string device_arch = dprops->gcnArchName;
    result = device_arch.find("gfx11") != std::string::npos ||
             device_arch.find("gfx12") != std::string::npos;
    is_cached = true;
  }
  return result;
}
Severity: high

The current implementation of on_gfx1x is not thread-safe. If multiple threads call this function for the first time concurrently, it can lead to a data race on the is_cached and result static variables. While this might be a benign race in many cases, it's better to avoid undefined behavior.

You can make this function thread-safe and more concise by using a lambda to initialize a static const variable. C++11 and later guarantee that the initialization of local static variables is thread-safe.

bool on_gfx1x() {
  static const bool result = [] {
    const auto* dprops = at::cuda::getCurrentDeviceProperties();
    const std::string device_arch = dprops->gcnArchName;
    return device_arch.find("gfx11") != std::string::npos ||
           device_arch.find("gfx12") != std::string::npos;
  }();
  return result;
}

mergify bot commented Feb 17, 2026

Hi @laudney, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


@laudney laudney force-pushed the feat/rocm-rdna4-wvsplitk branch from afbdd14 to 242cddf Compare February 17, 2026 20:41
@AndreasKaratzas (Collaborator) commented:

@amd-hhashemi Could you please take a look at this? It passes the AMD CI tests with no new regressions, at least. But I would like a more experienced eye on this if possible :)

@amd-hhashemi (Contributor) commented Feb 19, 2026

Is this related to #34177

@AndreasKaratzas (Collaborator) commented:

Is this related to #34177

No; you were just the only person I could think of with knowledge of the subject, and thus the ability to review this. Feel free to forward it to someone else if you don't have the time, of course 😅

@amd-hhashemi (Contributor) commented:

Sorry, I gave the wrong link. I meant #34177, which is from Mathias. It also targets skinny GEMM on RDNA.

@AndreasKaratzas (Collaborator) commented:

@amd-hhashemi Oh wow, yeah, the changes look similar.
@mgehre-amd Could you please take a look? We think this work overlaps with what you are doing. Is this one here a duplicate PR?

@laudney (Contributor, Author) commented Feb 19, 2026

@amd-hhashemi @AndreasKaratzas @mgehre-amd There's overlap for sure — both PRs enable wvSplitK on RDNA wave32. We got to similar places independently.

This PR goes further though: it covers gfx12/RDNA4 (not just gfx11), adds full FP8 wvSplitKQ support for gfx12 via dot4_f32_fp8_fp8 scalar accumulation, dead-codes MFMA paths with #ifdef to avoid Clang parsing gfx9 asm in dead branches, and includes dispatch unit tests. Macro naming follows the existing __HIP__GFX1X__ convention in the file.

One thing worth picking up from #34177: the atol = eps * sqrt(k) test tolerance — that's a nice fix.

Happy to coordinate with @mgehre-amd on how to proceed.
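
For reference, the eps * sqrt(k) rule mentioned above amounts to a helper like this (illustrative, not the actual test code — rounding error of a length-k fp16 accumulation grows roughly with sqrt(k)):

```python
FP16_EPS = 2.0 ** -10  # machine epsilon of IEEE binary16

def wvsplitk_atol(k: int) -> float:
    # absolute tolerance for a length-k fp16 dot-product accumulation
    return FP16_EPS * k ** 0.5
```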

@amd-hhashemi (Contributor) commented Feb 20, 2026

Also note this likely conflicts with #33762 (which converts asm to builtins and adds padding support to wvSplitK, as needed by some recent models) and #34100 (which converts the MFMAs in wvSplitK_fp8 to 16x16, attaining better perf with lower register pressure). But I'm fine with this or #34177 going first and resolving conflicts after.

}
}
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX1X__)
Suggested change
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX1X__)
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)

gfx11 doesn't have fp8 support.

@laudney (Contributor, Author) commented Mar 19, 2026

Hey, just pushed the clang-format fix as a single commit on top — no force push.

I also checked against latest main and there are no merge conflicts, so a rebase shouldn't be necessary.

@amd-hhashemi (Contributor) commented:
Hey, just pushed the clang-format fix as a single commit on top — no force push.

I also checked against latest main and there are no merge conflicts, so a rebase shouldn't be necessary.

Thanks, you might need to rebase (or sync) due to CI issues, not your code.

L.B.R. added 7 commits March 19, 2026 17:27
gfx1x (RDNA4/gfx12) hardware does not yet have full wvSplitK support.
This adds a dedicated code path that routes n==1 decode GEMMs through
the LLMM1 kernel on gfx1x, with appropriate dimension and dtype
guards. Includes unit tests verifying the dispatch and fallback.

Signed-off-by: L.B.R. <lbr@mmonad.com>
Enable the wvSplitK skinny GEMM kernel on RDNA (gfx11/gfx12) hardware
for decode-phase unquantized GEMMs (N=1..4). Previously these kernels
were gfx9-only (MI-series), falling through to torch linear on RDNA.

Key changes:
- Add `__HIP__GFX1X__` compile-time macro for gfx1100-gfx1201
- Add `on_gfx1x()` runtime detection for wave32 dispatch
- Parameterize kernel launch macros (WVSPLITK_CFG/WVSPLIT_TILE_CFG)
  on THRDS and WvPrGrp to support both wave64 (gfx9) and wave32 (RDNA)
- Replace gfx9-only DPP reduction (wave_shr:1, row_bcast:15/31) with
  RDNA wave32 alternative (row_shr:1 + __shfl_xor(val, 16))
- Add gfx1x DOT2C using v_dot2_f32_f16 inline asm (VOP3-P encoding,
  dot10-insts) instead of gfx9's v_dot2c_f32_f16 (VOP2)
- Guard MFMA bf16 paths with #ifdef __HIP__GFX9__ (prevents Clang
  from validating gfx9-only inline asm in dead if-constexpr branches)
- Unify Python dispatch: gfx1x now shares the same wvSplitK/LLMM1
  path as gfx9 (removes the prior LLMM1-only gfx1x workaround)

~15% decode token/s improvement on AMD Radeon AI PRO R9700 (gfx1201).

Tested: Qwen3-Coder-30B-A3B AWQ-4bit inference on gfx1201.
Signed-off-by: L.B.R. <lbr@mmonad.com>
Port the wvSplitKQ FP8 skinny GEMM kernel to RDNA4 (gfx12) using
scalar v_dot4_f32_fp8_fp8 instructions instead of MFMA/WMMA matrix
ops. This accelerates FP8 per-tensor scaled matmuls during decode
(M<=4) on RDNA4 GPUs.

Key changes:
- Add __HIP__GFX12__ compile-time macro for gfx1200/gfx1201
- Add on_gfx12() runtime detection helper
- Both wvSplitKQ_hf_sml_ and wvSplitKQ_hf_ kernels gain gfx1x path:
  - Scalar float accumulator (vs floatx16 MFMA vector on gfx9)
  - __builtin_amdgcn_dot4_f32_fp8_fp8() for 4-way FP8 dot products
  - DPP wave32 reduction: row_shr:8/4/2/1 + __shfl_xor(val, 16)
- Host dispatch: WVSPLITKQ macro gains _THRDS param, runtime
  on_gfx1x() selects THRDS=32/WvPrGrp=16 for wave32

Signed-off-by: L.B.R. <lbr@mmonad.com>
Narrow wvSplitKQ kernel compile guards from __HIP__GFX1X__ to
__HIP__GFX12__ since gfx11 (RDNA3) has no FP8 hardware — the previous
guard let gfx11 fall through to the MFMA #else path it can't execute.

Apply eps * sqrt(k) absolute tolerance to wvSplitK tests to properly
model fp16 accumulation error scaling (credit: @mgehre-amd vllm-project#34177).

Signed-off-by: L.B.R. <lbr@mmonad.com>
Addresses review feedback from @amd-hhashemi: since use_skinny
is already checked with an early return to torch.nn.functional.linear,
the subsequent `if use_skinny and` conditions are redundant.

Signed-off-by: L.B.R. <lbr@mmonad.com>
Co-authored-by: Claude
Signed-off-by: Bren Norris <bnorris@mmonad.com>

Signed-off-by: L.B.R. <lbr@mmonad.com>
@laudney laudney force-pushed the feat/rocm-rdna4-wvsplitk branch from 1521437 to 9ed6beb Compare March 19, 2026 17:27
@laudney (Contributor, Author) commented Mar 19, 2026

Rebased onto latest main and force pushed.

@JartX (Contributor) commented Mar 19, 2026

Hi @laudney, would you be so kind as to try this branch? https://github.com/JartX/vllm/commits/fix/rocm-cudagraph-memory-profiling-startup-oom-performance/ It's based on yours; I'd say there's about 5% more performance, but no more.

@AndreasKaratzas (Collaborator) commented:

@laudney Thank you for the contribution! These look legit:

[2026-03-19T18:20:04Z] FAILED model_executor/layers/test_rocm_unquantized_gemm.py::test_rocm_unquantized_gemm_gfx1x_wvsplitk_path - AttributeError: <module 'vllm.model_executor.layers.utils' from '/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/utils.py'> has no attribute 'get_cu_count'
[2026-03-19T18:20:04Z] FAILED model_executor/layers/test_rocm_unquantized_gemm.py::test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back - AttributeError: <module 'vllm.model_executor.layers.utils' from '/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/utils.py'> has no attribute 'get_cu_count'
[2026-03-19T18:20:04Z] FAILED model_executor/layers/test_rocm_unquantized_gemm.py::test_rocm_unquantized_gemm_gfx950_wvsplitkrc_path - AttributeError: <module 'vllm.model_executor.layers.utils' from '/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/utils.py'> has no attribute 'get_cu_count'

I think that you are supposed to add a skip for CUDA/NVIDIA arch. Other than that CI should be green.

@AndreasKaratzas (Collaborator) left a comment:


LGTM; make this change so that CI is green.

import torch

from vllm.model_executor.layers import utils

@AndreasKaratzas (Collaborator) commented:

Add something like:

import pytest

from vllm.platforms import current_platform

if current_platform.is_cuda():
    # or something similar -- with the reason for skipping
    # (i.e. it's a ROCm-specific test)
    pytest.skip("ROCm-specific test", allow_module_level=True)

Co-authored-by: Claude
Signed-off-by: Bren Norris <bnorris@mmonad.com>

Signed-off-by: L.B.R. <lbr@mmonad.com>
@laudney (Contributor, Author) commented Mar 19, 2026

Thanks @AndreasKaratzas, good catch. Just pushed a commit adding the skipif for non-ROCm platforms on those tests.

@AndreasKaratzas (Collaborator) left a comment:

@laudney Not quite what I meant 😅

from vllm.platforms import current_platform


@pytest.mark.skipif(not current_platform.is_rocm(), reason="ROCm-only test")
@AndreasKaratzas (Collaborator):

I think it's best if you put a "header" skip for this entire file and DRY it.

@AndreasKaratzas (Collaborator):

Check out vllm/tests/kernels/attention/test_trtllm_kvfp8_dequant.py:

if current_platform.is_rocm():
    pytest.skip(
        "trtllm kvfp8 dequant is not supported on ROCm.",
        allow_module_level=True,
    )

Mirror the logic but for CUDA.

Co-authored-by: Claude
Signed-off-by: Bren Norris <bnorris@mmonad.com>

Signed-off-by: L.B.R. <lbr@mmonad.com>
@laudney (Contributor, Author) commented Mar 19, 2026

Ah my bad, switched to module-level pytest.skip like test_trtllm_kvfp8_dequant.py. Should be good now.

@laudney (Contributor, Author) commented Mar 20, 2026

Hey, is there anything else I need to do to get this merged? This PR has been approved 3 times (by @mgehre-amd, @amd-hhashemi, @gshtras) and CI is green except for an unrelated elastic-ep-scaling-test failure. What's the holdup?

@gshtras gshtras merged commit 1779c09 into vllm-project:main Mar 20, 2026
135 of 136 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 20, 2026
chooper26 pushed a commit to intellistream/vllm-hust that referenced this pull request Mar 21, 2026
…m-project#34709)

Signed-off-by: L.B.R. <lbr@mmonad.com>
Co-authored-by: L.B.R. <lbr@mmonad.com>
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
…m-project#34709)

Signed-off-by: L.B.R. <lbr@mmonad.com>
Co-authored-by: L.B.R. <lbr@mmonad.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
…m-project#34709)

Signed-off-by: L.B.R. <lbr@mmonad.com>
Co-authored-by: L.B.R. <lbr@mmonad.com>
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…m-project#34709)

Signed-off-by: L.B.R. <lbr@mmonad.com>
Co-authored-by: L.B.R. <lbr@mmonad.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
…m-project#34709)

Signed-off-by: L.B.R. <lbr@mmonad.com>
Co-authored-by: L.B.R. <lbr@mmonad.com>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
…m-project#34709)

Signed-off-by: L.B.R. <lbr@mmonad.com>
Co-authored-by: L.B.R. <lbr@mmonad.com>

Labels

ci/build · new-model (Requests to new models) · ready (ONLY add when PR is ready to merge/full CI is needed) · rocm (Related to AMD ROCm)

Projects

Status: Done

8 participants