
[ROCm] Split wvSplitKrc into deterministic/fast kernels, add --fast-skinny-gemm CLI flag, overhaul tests#35183

Closed
AndreasKaratzas wants to merge 15 commits into vllm-project:main from ROCm:akaratza_skinny_gemm_determinism

Conversation


AndreasKaratzas (Collaborator) commented Feb 24, 2026

Follow-up to #34304, which introduced the deterministic store-then-reduce path for wvSplitKrc. This PR separates the two reduction strategies into distinct kernels (wvSplitKrc_deterministic_ and wvSplitKrc_fast_), adds a CLI flag to opt into the fast path, and replaces the test suite with a principled one that runs in under 7 minutes instead of 90+.

Motivation

The original #34304 implementation used a single wvSplitKrc_ kernel with a runtime DTRMNSTC flag selecting between deterministic (store-then-reduce) and non-deterministic (atomicAdd) reduction. Splitting it into two kernels, wvSplitKrc_deterministic_ and wvSplitKrc_fast_, lets the compiler optimize each independently without carrying dead code.

The default is deterministic (--fast-skinny-gemm=False). The rationale is the same as compiler optimization levels: users should get correct, reproducible results by default and opt into speed when they've validated it's acceptable for their workload. Bitwise reproducibility matters for speculative decoding verification, debugging, and regulatory compliance (these shouldn't require a flag to enable).

Borrowed from: #34304 cc @amd-hhashemi

Kernel changes

wvSplitKrc_deterministic_ stores each K-shard's partial sums to a dedicated slice of the global buffer (glbl[... + M*N*shard_idx]), then the last shard to finish loads all slices via __builtin_amdgcn_global_load_lds and sums them in LDS. wvSplitKrc_fast_ uses atomicAdd to accumulate directly, which is faster but non-deterministic across runs with different CU scheduling.
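
To see why the atomicAdd path cannot be bitwise reproducible, here is a minimal Python sketch; the kernels themselves are not reproduced, and numpy fp32 scalars stand in for the per-shard partial sums:

```python
import numpy as np

# fp32 addition is not associative: summing the same partials in a different
# order can round differently. The deterministic kernel fixes the order
# (slice 0, 1, 2, ...); atomicAdd lets CU scheduling pick the order.
partials = [np.float32(1e8), np.float32(1.0), np.float32(-1e8)]

def reduce_in_order(vals):
    acc = np.float32(0.0)
    for v in vals:
        acc = np.float32(acc + v)
    return acc

in_order = reduce_in_order(partials)  # 1e8 + 1 rounds back to 1e8, then -1e8 -> 0.0
scheduled = reduce_in_order([partials[0], partials[2], partials[1]])  # cancel first, then +1 -> 1.0
assert in_order != scheduled                  # order changed the answer
assert reduce_in_order(partials) == in_order  # fixed order is bitwise repeatable
```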

The wvSplitKrc host function now correctly reads M from in_b.size(0) and N from in_a.size(0), matching the kernel's actual A/B roles. It also passes Kap_in = in_a.stride(0) to support padded (non-contiguous stride) input tensors (the kernels use Kap instead of K for address calculations on A).
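
The Kap parameter covers activations that are strided views of a wider buffer. A numpy illustration of the padded-stride case (torch's `stride(0)` corresponds to numpy's `strides[0] // itemsize`):

```python
import numpy as np

# A row-sliced view of a wider buffer: logical K = 10, but each row actually
# starts 16 elements after the previous one. Addressing A with K instead of
# Kap would read the padding bytes.
buf = np.zeros((4, 16), dtype=np.float32)
a = buf[:, :10]                    # non-contiguous view with padded rows
K = a.shape[1]
Kap = a.strides[0] // a.itemsize   # leading stride in elements
assert K == 10 and Kap == 16       # kernel must index row i at A + i * Kap
```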

The per-call torch::empty + .zero_() for the reduction buffer is replaced with a persistent static torch::Tensor of 8M floats (32 MB), avoiding allocation overhead on every GEMM call. The deterministic kernel clears its slices as part of the reduction, and the fast kernel clears via zeroing after the last shard reads back.
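
A Python stand-in for the persistent-buffer pattern (the actual code is a C++ `static torch::Tensor`; the helper name here is made up):

```python
import numpy as np

_reduce_buf = None  # module-level stand-in for the C++ static tensor

def get_reduce_buffer() -> np.ndarray:
    """Allocate the 8M-float (32 MB) reduce buffer once and reuse it."""
    global _reduce_buf
    if _reduce_buf is None:
        _reduce_buf = np.zeros(8 * 1024 * 1024, dtype=np.float32)
    return _reduce_buf
```

Because the kernels restore their slices to zero as they consume them, no per-call `.zero_()` pass is needed.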

In rocm_unquantized_gemm_impl, wvSplitKrc is now checked before the aiter triton GEMM fallback, with an added guard m * n <= 128 * 1024 to stay within the static buffer capacity. The contiguity check was corrected from x.is_contiguous() to weight.is_contiguous() (the kernel needs contiguous B/weight, not A/activations, since A can have stride padding).
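
A hedged condensation of that dispatch guard; the function name is illustrative, and the real check in rocm_unquantized_gemm_impl may carry additional conditions:

```python
MAX_REDUCE_ELEMS = 128 * 1024  # keeps m * n within the static reduce buffer

def can_use_wvSplitKrc(m: int, n: int, k: int, weight_contiguous: bool) -> bool:
    return (
        weight_contiguous          # kernel needs contiguous B/weight, not A
        and k % 8 == 0
        and k > 512
        and m % 16 == 0
        and m * n <= MAX_REDUCE_ELEMS
    )
```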

Original Contributions

CLI integration

New --fast-skinny-gemm flag on ModelConfig / EngineArgs, plumbed through GPUModelRunner -> set_fast_skinny_gemm() -> rocm_unquantized_gemm_impl. Only affects ROCm wvSplitKrc calls.
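
A sketch of that plumbing, under the stated assumption that a module-level flag backs set_fast_skinny_gemm(); the real signatures in vllm may differ:

```python
_fast_skinny_gemm = False  # deterministic store-then-reduce by default

def set_fast_skinny_gemm(enabled: bool) -> None:
    # Called once by the model runner after --fast-skinny-gemm is parsed.
    global _fast_skinny_gemm
    _fast_skinny_gemm = enabled

def select_wvSplitKrc_kernel() -> str:
    # The host dispatcher picks the kernel variant from the module flag.
    return "wvSplitKrc_fast_" if _fast_skinny_gemm else "wvSplitKrc_deterministic_"
```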

Test overhaul

The previous test suite ran for 82+ minutes (12,847 passed, 1,530 skipped) before being interrupted, with significant time wasted on redundant parametrization and shapes that were skipped at runtime due to CU capacity. The new suite runs in under 7 minutes with the following changes:

Full cartesian products of distributions x bias modes x shapes x seeds are replaced with targeted combinations. Dimension sweeps use representative subsets (N_RC = [13, 16, 32, 64, 103, 128]) instead of exhaustive lists. Seed parametrization is removed (seeds are fixed per-test).

Each kernel gets a tolerance function derived from its accumulation model:

  • wvSplitK / wvSplitKrc: fp32 accumulation -> 1 ULP of output dtype (~7.8e-3 bf16, ~9.8e-4 fp16)
  • LLMM1: dtype-precision FMA accumulation -> 5 ULP (5x wider due to __hfma2 rounding)
  • wvSplitKQ (FP8): quantization noise O(sqrt(K) * eps_fp8) + output rounding, capped at 0.05
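
The ULP figures above follow directly from mantissa widths. An illustrative derivation (helper names are ours, not the suite's):

```python
def ulp_near_one(mantissa_bits: int) -> float:
    # 1 ULP of a value in [1, 2) is 2**-mantissa_bits (explicit bits only).
    return 2.0 ** -mantissa_bits

BF16_ULP = ulp_near_one(7)    # 7 explicit mantissa bits -> ~7.8e-3
FP16_ULP = ulp_near_one(10)   # 10 explicit mantissa bits -> ~9.8e-4

def wvsplit_atol(dtype_ulp: float) -> float:
    return 1.0 * dtype_ulp    # fp32 accumulation: only the final downcast rounds

def llmm1_atol(dtype_ulp: float) -> float:
    return 5.0 * dtype_ulp    # __hfma2 rounds at dtype precision every step
```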

Instead of a single assert_close, _assert_accurate checks: (1) ≥ 99.999% of elements within tolerance, (2) no element exceeds 3x tolerance, (3) mean absolute error < 0.25x atol. This catches both systematic bias and catastrophic single-element failures.
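
A hedged reconstruction of that three-part check (the actual _assert_accurate may differ in detail):

```python
import numpy as np

def assert_accurate(out: np.ndarray, ref: np.ndarray, atol: float) -> None:
    err = np.abs(out.astype(np.float64) - ref.astype(np.float64))
    # (1) at least 99.999% of elements within tolerance
    assert (err <= atol).mean() >= 0.99999, "too many out-of-tolerance elements"
    # (2) no single element blows past 3x tolerance
    assert err.max() <= 3 * atol, "catastrophic single-element failure"
    # (3) mean absolute error bounds systematic bias
    assert err.mean() < 0.25 * atol, "systematic bias in mean absolute error"
```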

New test categories:

  • Determinism: bitwise-identical output across 10 runs (deterministic path only)
  • Logprobs: top-1 agreement ≥ 99%, top-5 overlap ≥ 95%, top-1 logprob diff ≤ 0.01 nats
  • NaN propagation: NaN in row 0 must appear in output row 0, must not leak to row 1
  • Zero/bias: zero inputs produce exact zero; zero inputs + bias produce exact bias
  • Distributions: normal, mixed_scale (10% hot channels), sparse_activations (ReLU'd)
  • E2E: test_e2e_logprob_reproducibility and test_e2e_logprob_stability using vllm_runner with TitanML/tiny-mixtral, checking bitwise logprob identity and ≤ 0.001 nat drift across runs
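
The determinism category boils down to comparing raw bytes rather than calling allclose. An illustrative version, with a seeded numpy matmul standing in for the kernel launch:

```python
import numpy as np

def check_bitwise_determinism(fn, runs: int = 10) -> None:
    # Bitwise identity: compare raw bytes, not np.allclose.
    baseline = fn().tobytes()
    for _ in range(runs - 1):
        assert fn().tobytes() == baseline

def run_once() -> np.ndarray:
    rng = np.random.default_rng(1234)  # fixed seed stands in for the kernel call
    a = rng.standard_normal((4, 64)).astype(np.float32)
    b = rng.standard_normal((64, 8)).astype(np.float32)
    return a @ b

check_bitwise_determinism(run_once)
```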

The previous test_cross_kernel_consistency compared wvSplitK (N ≤ 4) against wvSplitKrc (N ≥ 9). These have zero overlap in supported N ranges, so every test case hit "Unsupported N value" at runtime.

Related

amd-hhashemi and others added 11 commits February 11, 2026 02:58
torch.linear(). Adjust reduce buffer usage to avoid need for zero_()
calls. Move priority above Aiter solution.

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
gemini-code-assist bot (Contributor) left a comment:
Code Review

This pull request introduces significant improvements to the ROCm skinny GEMM kernels. It splits the wvSplitKrc kernel into deterministic and fast versions, controlled by a new --fast-skinny-gemm CLI flag. This allows for reproducible results by default while providing an option for higher performance. The test suite has also been completely overhauled, resulting in faster and more comprehensive testing.

The changes are well-structured, but I have identified a few critical issues that should be addressed:

  • There is a large amount of code duplication between the new wvSplitKrc_deterministic_ and wvSplitKrc_fast_ kernels, which will be a maintenance burden.
  • A potential buffer overflow exists due to an insufficient guard on the static buffer size for wvSplitKrc.
  • A critical bug was introduced by removing necessary tensor reshape operations when calling wvSplitKrc, which will cause issues for inputs that are not 2D.

Comment on lines +188 to +190
return ops.wvSplitKrc(
x, weight, cu_count, bias, fast_skinny_gemm=_fast_skinny_gemm
)

critical

The call to ops.wvSplitKrc is missing the necessary reshape operations for the input and output tensors. The previous implementation correctly reshaped the input x to 2D before the kernel call and reshaped the output back to the original shape's prefix. The wvSplitKrc kernel expects a 2D input tensor (num_tokens, k) and returns a 2D output (num_tokens, m). If x is a 3D tensor (batch_size, seq_len, k), passing it directly will lead to incorrect dimensions being used inside the kernel and an incorrectly shaped output. This is a critical bug.

Suggested change: replace

    return ops.wvSplitKrc(
        x, weight, cu_count, bias, fast_skinny_gemm=_fast_skinny_gemm
    )

with

    x_view = x.reshape(-1, x.size(-1))
    out = ops.wvSplitKrc(
        x_view, weight, cu_count, bias, fast_skinny_gemm=_fast_skinny_gemm
    )
    return out.reshape(*x.shape[:-1], weight.shape[0])

Comment on lines +1809 to +1817
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N, int GrpsShrB, int CHUNKK>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
__attribute__((amdgpu_waves_per_eu(1, 1)))
wvSplitKrc_fast_(const int actlN, const int K, const int Kap, const int M,
const int Bx, const int By, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
const scalar_t* __restrict__ BIAS, float* glbl,
scalar_t* C, const int CuCount) {

high

There is a very large amount of code duplication between wvSplitKrc_fast_ and wvSplitKrc_deterministic_. The only substantial difference is the value of DTRMNSTC. This will make maintenance difficult, as any changes to the common logic will need to be applied in two places.

Consider refactoring this into a single templated implementation, for example:

template <bool DTRMNSTC /* plus the existing scalar_t/tile template params */>
__device__ void wvSplitKrc_impl_(...) {
  // __device__, not __global__: called as a plain function from the wrappers
  // ... common kernel logic ...
}

// Then define the two kernels as thin wrappers:
__global__ void __launch_bounds__(...)
wvSplitKrc_deterministic_(...) {
  wvSplitKrc_impl_<true>(...);
}

__global__ void __launch_bounds__(...)
wvSplitKrc_fast_(...) {
  wvSplitKrc_impl_<false>(...);
}

Since DTRMNSTC is a compile-time constant (constexpr), the compiler will optimize away the dead branches related to it, achieving the same performance benefit as two separate kernels without the code duplication.

and k % 8 == 0
and k > 512
and m % 16 == 0
and m * n <= 128 * 1024 # max reduce buffer

high

The guard m * n <= 128 * 1024 is not sufficient to prevent a buffer overflow in the wvSplitKrc kernel. The memory required by the deterministic path depends on k_rnd, which is proportional to the K dimension. A very large K could cause an overflow even if m * n is within the limit.

A more robust check should also consider the K dimension. The logic to calculate the required buffer size is present in the wvSplitKrc host function and can be replicated here. Specifically, k_rnd is roughly (k + kFit - 1) // kFit, where kFit depends on chunkk, which in turn depends on CuNeeded. All this information is available in this function. The check should be something like m * n * k_rnd <= 8 * 1024 * 1024.
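
The suggested K-aware guard can be sketched as follows; k_fit is left as a caller-supplied placeholder, since the real kFit/chunkk derivation depends on CU count:

```python
BUF_ELEMS = 8 * 1024 * 1024  # static reduce buffer: 8M floats (32 MB)

def fits_reduce_buffer(m: int, n: int, k: int, k_fit: int) -> bool:
    k_rnd = (k + k_fit - 1) // k_fit  # number of K shards, each with an M*N slice
    return m * n * k_rnd <= BUF_ELEMS
```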

AndreasKaratzas (author) commented:

@amd-hhashemi let me know your thoughts on:

  • Duplicating the kernels to give the scheduler a clear path and save that one conditional of computation
  • Test coverage
  • Anything else that I might have done wrong 😅

Comment on lines +286 to +291
fast_skinny_gemm: bool = False
"""When enabled on ROCm, uses the non-deterministic atomicAdd reduction
path in skinny GEMM kernels (wvSplitKrc) for higher throughput. The
default (False) uses a deterministic store-then-reduce path that
guarantees bitwise reproducibility across runs with different batch
dimensions, at a small cost in LDS pressure and an extra sync."""
A project member commented:

Should/could this belong to KernelConfig?

AndreasKaratzas (author) replied:

Yep, that's a better place. Thanks for the suggestion :)


mergify bot commented Feb 25, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @AndreasKaratzas.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


gshtras commented Mar 9, 2026

Once available, could we have perf comparison for the 2 paths?


@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 15, 2026
@AndreasKaratzas AndreasKaratzas deleted the akaratza_skinny_gemm_determinism branch March 15, 2026 19:48

Labels

needs-rebase, rocm (Related to AMD ROCm), v1

Projects

Status: Done


4 participants