[ROCm] Split wvSplitKrc into deterministic/fast kernels, add --fast-skinny-gemm CLI flag, overhaul tests#35183
Conversation
torch.linear(). Adjust reduce buffer usage to avoid need for zero_() calls. Move priority above Aiter solution.
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
…_skinny_gemm_determinism
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Code Review
This pull request introduces significant improvements to the ROCm skinny GEMM kernels. It splits the wvSplitKrc kernel into deterministic and fast versions, controlled by a new --fast-skinny-gemm CLI flag. This allows for reproducible results by default while providing an option for higher performance. The test suite has also been completely overhauled, resulting in faster and more comprehensive testing.
The changes are well-structured, but I have identified a few critical issues that should be addressed:
- There is a large amount of code duplication between the new wvSplitKrc_deterministic_ and wvSplitKrc_fast_ kernels, which will be a maintenance burden.
- A potential buffer overflow exists due to an insufficient guard on the static buffer size for wvSplitKrc.
- A critical bug was introduced by removing necessary tensor reshape operations when calling wvSplitKrc, which will cause issues for inputs that are not 2D.
    return ops.wvSplitKrc(
        x, weight, cu_count, bias, fast_skinny_gemm=_fast_skinny_gemm
    )
The call to ops.wvSplitKrc is missing the necessary reshape operations for the input and output tensors. The previous implementation correctly reshaped the input x to 2D before the kernel call and reshaped the output back to the original shape's prefix. The wvSplitKrc kernel expects a 2D input tensor (num_tokens, k) and returns a 2D output (num_tokens, m). If x is a 3D tensor (batch_size, seq_len, k), passing it directly will lead to incorrect dimensions being used inside the kernel and an incorrectly shaped output. This is a critical bug.
Suggested change:

    -    return ops.wvSplitKrc(
    -        x, weight, cu_count, bias, fast_skinny_gemm=_fast_skinny_gemm
    -    )
    +    x_view = x.reshape(-1, x.size(-1))
    +    out = ops.wvSplitKrc(
    +        x_view, weight, cu_count, bias, fast_skinny_gemm=_fast_skinny_gemm
    +    )
    +    return out.reshape(*x.shape[:-1], weight.shape[0])
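For clarity, the flatten-then-restore pattern this suggestion relies on can be sketched in isolation; `gemm_2d` below is a hypothetical numpy stand-in for the 2D-only kernel:

```python
import numpy as np

def gemm_2d(x2d, weight):
    # Hypothetical stand-in for the 2D-only kernel:
    # (num_tokens, k) @ (m, k)^T -> (num_tokens, m)
    return x2d @ weight.T

def skinny_gemm(x, weight):
    # Flatten all leading dims into a single token dim before the kernel call...
    x_view = x.reshape(-1, x.shape[-1])
    out = gemm_2d(x_view, weight)
    # ...then restore the original shape prefix, with m as the new last dim.
    return out.reshape(*x.shape[:-1], weight.shape[0])

x = np.random.rand(2, 5, 8)   # e.g. (batch_size, seq_len, k)
w = np.random.rand(16, 8)     # (m, k)
assert skinny_gemm(x, w).shape == (2, 5, 16)
```

A 3D activation tensor passes through the 2D kernel unchanged in semantics, which is exactly what the removed reshapes guaranteed.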
    template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
              int UNRL, int N, int GrpsShrB, int CHUNKK>
    __global__ void __launch_bounds__(WvPrGrp* THRDS)
        __attribute__((amdgpu_waves_per_eu(1, 1)))
        wvSplitKrc_fast_(const int actlN, const int K, const int Kap, const int M,
                         const int Bx, const int By, const scalar_t* __restrict__ A,
                         const scalar_t* __restrict__ B,
                         const scalar_t* __restrict__ BIAS, float* glbl,
                         scalar_t* C, const int CuCount) {
There is a very large amount of code duplication between wvSplitKrc_fast_ and wvSplitKrc_deterministic_. The only substantial difference is the value of DTRMNSTC. This will make maintenance difficult, as any changes to the common logic will need to be applied in two places.
Consider refactoring this into a single templated implementation, for example:
    template <bool DTRMNSTC>
    __device__ void wvSplitKrc_impl_(...) {
      // ... common kernel logic ...
    }

    // Then define the two kernels as thin wrappers:
    __global__ void __launch_bounds__(...)
    wvSplitKrc_deterministic_(...) {
      wvSplitKrc_impl_<true>(...);
    }

    __global__ void __launch_bounds__(...)
    wvSplitKrc_fast_(...) {
      wvSplitKrc_impl_<false>(...);
    }

Note the shared implementation must be a __device__ function (a __global__ kernel cannot be called from inside another kernel). Since DTRMNSTC is a compile-time constant (a template parameter), the compiler will optimize away the dead branches related to it, achieving the same performance benefit as two separate kernels without the code duplication.
    and k % 8 == 0
    and k > 512
    and m % 16 == 0
    and m * n <= 128 * 1024  # max reduce buffer
The guard m * n <= 128 * 1024 is not sufficient to prevent a buffer overflow in the wvSplitKrc kernel. The memory required by the deterministic path depends on k_rnd, which is proportional to the K dimension. A very large K could cause an overflow even if m * n is within the limit.
A more robust check should also consider the K dimension. The logic to calculate the required buffer size is present in the wvSplitKrc host function and can be replicated here. Specifically, k_rnd is roughly (k + kFit - 1) // kFit, where kFit depends on chunkk, which in turn depends on CuNeeded. All this information is available in this function. The check should be something like m * n * k_rnd <= 8 * 1024 * 1024.
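A sketch of the proposed guard; `k_fit` here is a placeholder parameter for the per-shard K capacity that the host-side chunkk/CuNeeded logic would produce (the exact derivation lives in the wvSplitKrc host function):

```python
MAX_REDUCE_FLOATS = 8 * 1024 * 1024  # capacity of the 32 MB fp32 static buffer

def fits_reduce_buffer(m, n, k, k_fit):
    # `k_fit` is a placeholder for the per-shard K capacity derived from the
    # host-side chunkk/CuNeeded logic; k_rnd is the resulting K-shard count.
    k_rnd = (k + k_fit - 1) // k_fit
    # The deterministic path needs one M*N slice per K-shard, so the guard
    # must scale with k_rnd, not just m * n.
    return m * n * k_rnd <= MAX_REDUCE_FLOATS
```

With this form, a shape that passes `m * n <= 128 * 1024` but has a very large K (hence many shards) is correctly rejected.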
@amd-hhashemi let me know your thoughts on the points above.
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
    fast_skinny_gemm: bool = False
    """When enabled on ROCm, uses the non-deterministic atomicAdd reduction
    path in skinny GEMM kernels (wvSplitKrc) for higher throughput. The
    default (False) uses a deterministic store-then-reduce path that
    guarantees bitwise reproducibility across runs with different batch
    dimensions, at a small cost in LDS pressure and an extra sync."""
Should/could this belong to KernelConfig?
Yep, that's a better place. Thanks for the suggestion :)
This pull request has merge conflicts that must be resolved before it can be merged.
Once available, could we have a perf comparison for the two paths?
Closing this PR after #34304 was merged. Determinism is now on by default: https://github.com/amd-hhashemi/vllm/blob/4cecae9c3fd52179f8bd5cd4fb36ff05d086cd5d/csrc/rocm/skinny_gemms.cu#L1670
Follow-up to #34304, which introduced the deterministic store-then-reduce path for wvSplitKrc. This PR separates the two reduction strategies into distinct kernels (wvSplitKrc_deterministic_ and wvSplitKrc_fast_), adds a CLI flag to opt into the fast path, and replaces the test suite with a principled one that runs in under 7 minutes instead of 90+.

Motivation
The original #34304 implementation used a single wvSplitKrc_ kernel with a runtime DTRMNSTC constant selecting between deterministic (store-then-reduce) and non-deterministic (atomicAdd) reduction. Splitting into two kernels, i.e., wvSplitKrc_deterministic_ and wvSplitKrc_fast_, lets the compiler optimize each independently without carrying dead code.

The default is deterministic (--fast-skinny-gemm=False). The rationale is the same as compiler optimization levels: users should get correct, reproducible results by default and opt into speed when they've validated it's acceptable for their workload. Bitwise reproducibility matters for speculative decoding verification, debugging, and regulatory compliance (these shouldn't require a flag to enable).

Borrowed from: #34304 cc @amd-hhashemi
Kernel changes
wvSplitKrc_deterministic_ stores each K-shard's partial sums to a dedicated slice of the global buffer (glbl[... + M*N*shard_idx]), then the last shard to finish loads all slices via __builtin_amdgcn_global_load_lds and sums them in LDS. wvSplitKrc_fast_ uses atomicAdd to accumulate directly, which is faster but non-deterministic across runs with different CU scheduling.

The wvSplitKrc host function now correctly reads M from in_b.size(0) and N from in_a.size(0), matching the kernel's actual A/B roles. It also passes Kap_in = in_a.stride(0) to support padded (non-contiguous stride) input tensors (the kernels use Kap instead of K for address calculations on A).

The per-call torch::empty + .zero_() for the reduction buffer is replaced with a persistent static torch::Tensor of 8M floats (32 MB), avoiding allocation overhead on every GEMM call. The deterministic kernel clears its slices as part of the reduction, and the fast kernel clears via zeroing after the last shard reads back.
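The persistent buffer is the usual lazily initialized static; a minimal Python sketch of the same caching pattern (names hypothetical, numpy standing in for torch):

```python
import numpy as np

_REDUCE_BUF = None  # module-level, persists like the C++ `static torch::Tensor`

def get_reduce_buffer():
    # Hypothetical helper: allocate the 8M-float (32 MB) buffer once on first
    # use and reuse it on every later call, avoiding per-GEMM allocation.
    # Callers clear only the slices they touch instead of a full zero_().
    global _REDUCE_BUF
    if _REDUCE_BUF is None:
        _REDUCE_BUF = np.zeros(8 * 1024 * 1024, dtype=np.float32)
    return _REDUCE_BUF
```

The trade-off is a fixed 32 MB of resident memory in exchange for removing an allocation and a full clear from every GEMM call, which is why the guard on buffer capacity matters.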
In rocm_unquantized_gemm_impl, wvSplitKrc is now checked before the aiter triton GEMM fallback, with an added guard m * n <= 128 * 1024 to stay within the static buffer capacity. The contiguity check was corrected from x.is_contiguous() to weight.is_contiguous() (the kernel needs contiguous B/weight, not A/activations, since A can have stride padding).

Original Contributions
CLI integration
New --fast-skinny-gemm flag on ModelConfig / EngineArgs, plumbed through GPUModelRunner -> set_fast_skinny_gemm() -> rocm_unquantized_gemm_impl. Only affects ROCm wvSplitKrc calls.
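As a rough illustration of the opt-in default behavior, a generic argparse sketch (this is not the actual EngineArgs wiring, just the standard boolean-flag pattern):

```python
import argparse

# Generic sketch of a boolean opt-in flag; the real vLLM ModelConfig/EngineArgs
# plumbing differs in detail. Deterministic stays the default (flag absent).
parser = argparse.ArgumentParser()
parser.add_argument("--fast-skinny-gemm", action="store_true",
                    help="Opt into the non-deterministic atomicAdd path.")

args = parser.parse_args([])  # no flag -> deterministic path by default
# argparse converts dashes to underscores: args.fast_skinny_gemm
```

`store_true` gives exactly the desired semantics: absent means False (deterministic), present means True (fast path).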
Test overhaul

The previous test suite ran for 82+ minutes (12,847 passed, 1,530 skipped) before being interrupted, with significant time wasted on redundant parametrization and shapes that were skipped at runtime due to CU capacity. The new suite runs in under 7 minutes with the following changes:
Full cartesian products of distributions x bias modes x shapes x seeds are replaced with targeted combinations. Dimension sweeps use representative subsets (N_RC = [13, 16, 32, 64, 103, 128]) instead of exhaustive lists. Seed parametrization is removed (seeds are fixed per-test).

Each kernel gets a tolerance function derived from its accumulation model:
- wvSplitK / wvSplitKrc: fp32 accumulation -> 1 ULP of output dtype (~7.8e-3 bf16, ~9.8e-4 fp16)
- LLMM1: dtype-precision FMA accumulation -> 5 ULP (5x wider due to __hfma2 rounding)
- wvSplitKQ (FP8): quantization noise O(sqrt(K) * eps_fp8) + output rounding, capped at 0.05
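The quoted 1-ULP figures follow from each output dtype's stored mantissa width at magnitude ~1.0; a small sketch (helper name hypothetical):

```python
# Stored mantissa bits of the output dtypes; the spacing of representable
# values near 1.0 is 2**-bits, which matches the tolerances quoted above.
MANTISSA_BITS = {"bfloat16": 7, "float16": 10}

def ulp_tol(dtype, n_ulp=1):
    # 1 ULP at magnitude ~1.0, scaled by the kernel's ULP budget.
    return n_ulp * 2.0 ** -MANTISSA_BITS[dtype]

assert abs(ulp_tol("bfloat16") - 7.8125e-3) < 1e-12   # ~7.8e-3
assert abs(ulp_tol("float16") - 9.765625e-4) < 1e-12  # ~9.8e-4
```

The same helper covers the LLMM1 case by passing `n_ulp=5`.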
Instead of a single assert_close, _assert_accurate checks: (1) at least 99.999% of elements within tolerance, (2) no element exceeds 3x tolerance, (3) mean absolute error < 0.25x atol. This catches both systematic bias and catastrophic single-element failures.
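A hypothetical numpy re-creation of those three checks (names and thresholds as stated above, not the actual test-suite code):

```python
import numpy as np

def assert_accurate(out, ref, atol, frac=0.99999, hard_mult=3.0, mae_mult=0.25):
    # Hypothetical re-creation of the three-part accuracy check.
    err = np.abs(out.astype(np.float64) - ref.astype(np.float64))
    # (1) at least `frac` of elements within tolerance
    assert (err <= atol).mean() >= frac, "too many out-of-tolerance elements"
    # (2) no single element exceeds a hard cap (catastrophic failure)
    assert err.max() <= hard_mult * atol, "catastrophic single-element error"
    # (3) mean absolute error bounded (systematic bias)
    assert err.mean() <= mae_mult * atol, "systematic bias in output"
```

Check (2) catches a single wildly wrong element that a high pass fraction would hide, and check (3) catches a small but uniform bias that sits inside the per-element tolerance.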
New test categories: test_e2e_logprob_reproducibility and test_e2e_logprob_stability using vllm_runner with TitanML/tiny-mixtral, checking bitwise logprob identity and at most 0.001 nat drift across runs.

The previous test_cross_kernel_consistency compared wvSplitK (N <= 4) against wvSplitKrc (N >= 9). These have zero overlap in supported N ranges, so every test case hit "Unsupported N value" at runtime.

Related