[ROCm] Enable wvSplitK skinny GEMM kernel for RDNA4/gfx1x decode #34709
gshtras merged 9 commits into vllm-project:main
Conversation
Code Review
This pull request is a significant contribution that enables the wvSplitK skinny GEMM kernel for RDNA4/gfx1x architectures on ROCm, which should provide a notable performance improvement for decode-phase GEMMs. The changes are comprehensive, spanning from the CUDA kernel implementation to the Python dispatch logic, dependency management, and testing. The kernel modifications are well-structured, using preprocessor directives to handle different GPU architectures and parameterizing macros for wave32/64 support. The Python code updates correctly integrate the new kernel path and include a bug fix in the dispatch logic. The addition of new tests for the added functionality is also a great touch. I have one suggestion to improve the thread safety of a newly added utility function.
bool on_gfx1x() {
  static bool is_cached = false;
  static bool result = false;
  if (!is_cached) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    std::string device_arch = dprops->gcnArchName;
    result =
        device_arch.find("gfx11") != std::string::npos ||
        device_arch.find("gfx12") != std::string::npos;
    is_cached = true;
  }
  return result;
}
The current implementation of on_gfx1x is not thread-safe. If multiple threads call this function for the first time concurrently, it can lead to a data race on the is_cached and result static variables. While this might be a benign race in many cases, it's better to avoid undefined behavior.
You can make this function thread-safe and more concise by using a lambda to initialize a static const variable. C++11 and later guarantee that the initialization of local static variables is thread-safe.
bool on_gfx1x() {
static const bool result = [] {
const auto* dprops = at::cuda::getCurrentDeviceProperties();
const std::string device_arch = dprops->gcnArchName;
return device_arch.find("gfx11") != std::string::npos ||
device_arch.find("gfx12") != std::string::npos;
}();
return result;
}
Hi @laudney, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
Force-pushed from 4984778 to afbdd14.
Force-pushed from afbdd14 to 242cddf.
@amd-hhashemi Could you please take a look at this? It passes the AMD CI tests with no new regressions, at least. But I would like a more experienced eye on this if possible :)

Is this related to #34177?

No, you were the only person in my head that I know to have knowledge on the subject, and thus the ability to also review this. Feel free to forward it to someone else if you don't have the time, of course 😅

Sorry, I gave the wrong link. I meant #34177, which is from Mathias. It also targets skinny GEMM on RDNA.

@amd-hhashemi Oh wow, yeah, the changes look similar.

@amd-hhashemi @AndreasKaratzas @mgehre-amd There's overlap for sure — both PRs enable wvSplitK on RDNA wave32. We got to similar places independently. This PR goes further though: it covers gfx12/RDNA4 (not just gfx11) and adds full FP8 wvSplitKQ support for gfx12 via `__builtin_amdgcn_dot4_f32_fp8_fp8`. One thing worth picking up from #34177: the eps * sqrt(k) test tolerance. Happy to coordinate with @mgehre-amd on how to proceed.

Also note this likely has conflicts with #33762 (which converts asm to builtins and adds padding support to wvSplitK, as has been seen in some recent models) and #34100 (which converts the MFMAs in wvSplitK_fp8 to 16x16, as it attains better perf with lower reg pressure). But I'm fine with this or #34177 going first and resolving conflicts after.
csrc/rocm/skinny_gemms.cu (outdated)

  }
}
-#else  // !defined(__HIP__MI3XX__) TODO: Add NAVI support
+#else  // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX1X__)

Suggested change:

-#else  // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX1X__)
+#else  // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)

gfx11 doesn't have fp8 support.
Hey, just pushed the clang-format fix as a single commit on top — no force push. I also checked against latest main and there are no merge conflicts, so a rebase shouldn't be necessary. |
Thanks, you might need to rebase (or sync) due to CI issues, not your code. |
gfx1x (RDNA4/gfx12) hardware does not yet have full wvSplitK support. This adds a dedicated code path that routes n==1 decode GEMMs through the LLMM1 kernel on gfx1x, with appropriate dimension and dtype guards. Includes unit tests verifying the dispatch and fallback.

Signed-off-by: L.B.R. <lbr@mmonad.com>
Enable the wvSplitK skinny GEMM kernel on RDNA (gfx11/gfx12) hardware for decode-phase unquantized GEMMs (N=1..4). Previously these kernels were gfx9-only (MI-series), falling through to torch linear on RDNA.

Key changes:
- Add `__HIP__GFX1X__` compile-time macro for gfx1100-gfx1201
- Add `on_gfx1x()` runtime detection for wave32 dispatch
- Parameterize kernel launch macros (WVSPLITK_CFG/WVSPLIT_TILE_CFG) on THRDS and WvPrGrp to support both wave64 (gfx9) and wave32 (RDNA)
- Replace gfx9-only DPP reduction (wave_shr:1, row_bcast:15/31) with RDNA wave32 alternative (row_shr:1 + __shfl_xor(val, 16))
- Add gfx1x DOT2C using v_dot2_f32_f16 inline asm (VOP3-P encoding, dot10-insts) instead of gfx9's v_dot2c_f32_f16 (VOP2)
- Guard MFMA bf16 paths with #ifdef __HIP__GFX9__ (prevents Clang from validating gfx9-only inline asm in dead if-constexpr branches)
- Unify Python dispatch: gfx1x now shares the same wvSplitK/LLMM1 path as gfx9 (removes the prior LLMM1-only gfx1x workaround)

~15% decode token/s improvement on AMD Radeon AI PRO R9700 (gfx1201). Tested: Qwen3-Coder-30B-A3B AWQ-4bit inference on gfx1201.

Signed-off-by: L.B.R. <lbr@mmonad.com>
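The wave32 reduction mentioned in the commit message (DPP `row_shr` steps plus a final `__shfl_xor(val, 16)`) can be modeled as a butterfly sum over 32 lanes. The sketch below is a plain-Python illustration of the cross-lane exchange pattern, not the kernel's actual DPP instruction sequence:

```python
def wave32_butterfly_reduce(lane_vals):
    # Model of a 32-lane cross-lane sum: at each step, every lane adds the
    # value held by the lane whose index differs in exactly one bit, which
    # is what __shfl_xor(val, mask) exchanges on real hardware.
    assert len(lane_vals) == 32
    lanes = list(lane_vals)
    for mask in (1, 2, 4, 8, 16):
        lanes = [lanes[i] + lanes[i ^ mask] for i in range(32)]
    return lanes  # after the last step, every lane holds the full sum
```

After five exchange steps (masks 1 through 16), all 32 lanes converge on the same total, which is why the kernel can read the result from any single lane.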
Port the wvSplitKQ FP8 skinny GEMM kernel to RDNA4 (gfx12) using scalar v_dot4_f32_fp8_fp8 instructions instead of MFMA/WMMA matrix ops. This accelerates FP8 per-tensor scaled matmuls during decode (M<=4) on RDNA4 GPUs.

Key changes:
- Add __HIP__GFX12__ compile-time macro for gfx1200/gfx1201
- Add on_gfx12() runtime detection helper
- Both wvSplitKQ_hf_sml_ and wvSplitKQ_hf_ kernels gain a gfx1x path:
  - Scalar float accumulator (vs floatx16 MFMA vector on gfx9)
  - __builtin_amdgcn_dot4_f32_fp8_fp8() for 4-way FP8 dot products
  - DPP wave32 reduction: row_shr:8/4/2/1 + __shfl_xor(val, 16)
- Host dispatch: WVSPLITKQ macro gains _THRDS param, runtime on_gfx1x() selects THRDS=32/WvPrGrp=16 for wave32

Signed-off-by: L.B.R. <lbr@mmonad.com>
Signed-off-by: L.B.R. <lbr@mmonad.com>
Narrow wvSplitKQ kernel compile guards from __HIP__GFX1X__ to __HIP__GFX12__, since gfx11 (RDNA3) has no FP8 hardware; the previous guard let gfx11 fall through to the MFMA #else path it can't execute.

Apply an eps * sqrt(k) absolute tolerance to the wvSplitK tests to properly model fp16 accumulation error scaling (credit: @mgehre-amd, vllm-project#34177).

Signed-off-by: L.B.R. <lbr@mmonad.com>
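The tolerance model credited above can be sketched as follows. Using fp16 machine epsilon (2^-10) as the base `eps` is an assumption for illustration; the tests may use a different constant:

```python
import math

def wvsplitk_atol(k, eps=2**-10):
    # Under a random-error model, fp16 rounding errors in a length-k
    # accumulation grow roughly with sqrt(k), so the absolute tolerance
    # is scaled by sqrt(k) instead of being a flat constant.
    return eps * math.sqrt(k)
```

A flat tolerance tuned for small k would spuriously fail at large k (e.g. k=4096 allows 64x the base eps), which is the failure mode this scaling avoids.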
Addresses review feedback from @amd-hhashemi: since use_skinny is already checked with an early return to torch.nn.functional.linear, the subsequent `if use_skinny and` conditions are redundant.

Signed-off-by: L.B.R. <lbr@mmonad.com>
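The shape of that simplification can be sketched with a toy dispatcher. All names here are hypothetical stand-ins, not the real vllm functions:

```python
def gemm_dispatch(m, use_skinny, linear, wvsplitk):
    # Early return on use_skinny: the skinny-kernel branch below no
    # longer needs a redundant `if use_skinny and ...` re-check.
    if not use_skinny:
        return linear()
    if 0 < m <= 4:  # decode-phase skinny GEMM (M = 1..4)
        return wvsplitk()
    return linear()  # prefill and large batches stay on the generic path
```

The early return hoists the common guard once, so each subsequent shape check reads as a pure dimension test.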
Co-authored-by: Claude
Signed-off-by: Bren Norris <bnorris@mmonad.com>
Signed-off-by: L.B.R. <lbr@mmonad.com>
Force-pushed from 1521437 to 9ed6beb.
Rebased onto latest main and force-pushed.

Hi @laudney Would you be so kind as to try this branch? https://github.com/JartX/vllm/commits/fix/rocm-cudagraph-memory-profiling-startup-oom-performance/ It's based on yours; I'd say there's about 5% more performance, but no more.

@laudney Thank you for the contribution! These look legit:

[2026-03-19T18:20:04Z] FAILED model_executor/layers/test_rocm_unquantized_gemm.py::test_rocm_unquantized_gemm_gfx1x_wvsplitk_path - AttributeError: <module 'vllm.model_executor.layers.utils' from '/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/utils.py'> has no attribute 'get_cu_count'

I think that you are supposed to add a skip for CUDA/NVIDIA arch. Other than that, CI should be green.
AndreasKaratzas left a comment:

LGTM. Make this change so that CI is green.
import torch

from vllm.model_executor.layers import utils

Add something like:

from vllm.platforms import current_platform

if current_platform.is_cuda():
    pytest.skip(...)  # or something similar, with the reason for skipping (i.e. it's a ROCm-specific test)

Co-authored-by: Claude
Signed-off-by: Bren Norris <bnorris@mmonad.com>
Signed-off-by: L.B.R. <lbr@mmonad.com>
Thanks @AndreasKaratzas, good catch. Just pushed a commit adding the skipif for non-ROCm platforms on those tests.
from vllm.platforms import current_platform

@pytest.mark.skipif(not current_platform.is_rocm(), reason="ROCm-only test")
I think it's best if you put a "header" skip for this entire file and DRY it.
Check out vllm/tests/kernels/attention/test_trtllm_kvfp8_dequant.py:

if current_platform.is_rocm():
    pytest.skip(
        "trtllm kvfp8 dequant is not supported on ROCm.",
        allow_module_level=True,
    )

Mirror the logic but for CUDA.

Co-authored-by: Claude
Signed-off-by: Bren Norris <bnorris@mmonad.com>
Signed-off-by: L.B.R. <lbr@mmonad.com>
Ah, my bad; switched to module-level pytest.skip like test_trtllm_kvfp8_dequant.py. Should be good now.

Hey, is there anything else I need to do to get this merged? This PR has been approved 3 times (by @mgehre-amd, @amd-hhashemi, @gshtras) and CI is green except for an unrelated elastic-ep-scaling-test failure. What's the holdup?
Summary

Enable the `wvSplitK` and `wvSplitKQ` skinny GEMM kernels on RDNA (gfx11/gfx12) hardware for decode-phase GEMMs (M=1..4). Previously these kernels were gfx9-only (MI-series), with RDNA falling through to `torch.nn.functional.linear` (unquantized) or generic `torch._scaled_mm` (FP8).

~15% decode token/s improvement on AMD Radeon AI PRO R9700 (gfx1201).

Commit 1: `wvSplitK` — unquantized BF16/FP16 skinny GEMM

Kernel changes (`csrc/rocm/skinny_gemms.cu`):

- Add `__HIP__GFX1X__` compile-time macro for gfx1100/1101/1150/1151/1200/1201
- Add `on_gfx1x()` runtime detection for wave32 vs wave64 dispatch
- Kernel launch macros (`WVSPLITK_CFG`/`WVSPLIT_TILE_CFG`) accept THRDS and WvPrGrp to support both wave64 (gfx9) and wave32 (RDNA)
- Replace the gfx9-only DPP reduction (`wave_shr:1`, `row_bcast:15`, `row_bcast:31`) with an RDNA wave32 alternative (`row_shr:1` + `__shfl_xor(val, 16)`)
- Add gfx1x DOT2C using `v_dot2_f32_f16` inline asm (VOP3-P encoding, `dot10-insts`) instead of gfx9's `v_dot2c_f32_f16` (VOP2)
- Guard MFMA bf16 paths with `#ifdef __HIP__GFX9__` (prevents Clang from validating gfx9-only inline asm in dead `if constexpr` branches)
- Move `device_guard` before tensor allocations in `wvSplitKrc`

Python dispatch changes (`vllm/model_executor/layers/utils.py`):

- The `use_skinny` condition now includes `on_gfx1x()` alongside `on_gfx9()`
- Remove the prior LLMM1-only gfx1x workaround (`use_skinny_llmm1_gfx1x`)
- gfx1x now shares the same `wvSplitK` → `LLMM1` dispatch path as gfx9

Commit 2: `wvSplitKQ` — FP8 per-tensor skinny GEMM

Kernel changes (`csrc/rocm/skinny_gemms.cu`):

- Add `__HIP__GFX12__` compile-time macro for gfx1200/gfx1201
- Add `on_gfx12()` runtime detection for FP8-capable RDNA dispatch
- Both `wvSplitKQ_hf_sml_` and `wvSplitKQ_hf_` kernels gain a gfx1x path:
  - Scalar float accumulator (vs the `floatx16` MFMA vector on gfx9)
  - `__builtin_amdgcn_dot4_f32_fp8_fp8()` for 4-way FP8 dot products — hardware interprets FP8 bits per target arch (fnuz on gfx9, fn on gfx12)
  - DPP wave32 reduction: `row_shr:8/4/2/1` + `__shfl_xor(val, 16)`
  - `threadIdx.x == (THRDS - 1)` works for both wave sizes
- Host dispatch: the `WVSPLITKQ` macro gains a `_THRDS` parameter; runtime `on_gfx1x()` selects THRDS=32/WvPrGrp=16 for wave32

Related PRs (RDNA4/gfx12 series)

Test plan

- `test_rocm_wvsplitk_kernel`, `test_rocm_unquantized_gemm` dispatch tests
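The scalar-accumulator inner loop described for gfx12 can be modeled in plain Python, with each 4-wide chunk standing in for one `__builtin_amdgcn_dot4_f32_fp8_fp8()` call. Values here are ordinary floats rather than FP8-encoded bytes; the point is the accumulation structure, not the numerics:

```python
def dot4_inner_loop(a, b):
    # Consume K in 4-wide chunks, accumulating into a single float scalar:
    # the gfx12 analogue of the floatx16 MFMA vector accumulator on gfx9.
    assert len(a) == len(b) and len(a) % 4 == 0
    acc = 0.0
    for i in range(0, len(a), 4):
        # one 4-way dot product per chunk, as the dot4 builtin computes
        acc += sum(a[i + j] * b[i + j] for j in range(4))
    return acc
```

On hardware, the per-lane scalars produced by this loop are then combined across the wave by the DPP/`__shfl_xor` reduction described above.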