From f1f94cf0cde84ceda4936247b28418385e92afcb Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 30 Mar 2026 07:43:37 -0400 Subject: [PATCH 1/5] Enable split-K for block-scale A8W8 CK and CKTile GEMMs Propagate the splitK parameter (as KBatch = 2^splitK) through the block-scale GEMM kernel infrastructure so that the tuning scripts can sweep split-K values to improve occupancy on small-M shapes. CK path: add KBatch parameter to gemm_a8w8_blockscale_impl and call SetKBatch on the device argument. The CK invoker handles output zeroing and atomic accumulation internally. CKTile path: add k_batch parameter to gemm_a8w8_blockscale_cktile_impl, remove the "split-k is not supported yet" runtime guard, and add hipMemsetAsync to zero the output buffer before atomic accumulation. Non-tune entry points pass KBatch=1 (no split-K) to preserve existing behavior. Code generation scripts (gen_instances.py, gen_instances_cktile.py) updated to include the new parameter in generated wrappers and manifests. Made-with: Cursor --- .../gemm_a8w8_blockscale.cu | 6 +++--- .../gemm_a8w8_blockscale_cktile.cu | 6 +++--- .../gemm_a8w8_blockscale_cktile_tune.cu | 6 +++--- .../gemm_a8w8_blockscale_tune.cu | 6 +++--- csrc/ck_gemm_a8w8_blockscale/gen_instances.py | 11 ++++++---- .../gen_instances_cktile.py | 11 ++++++---- .../gemm_a8w8_blockscale_cktile_common.cuh | 20 +++++++++++-------- .../include/gemm_a8w8_blockscale_common.cuh | 8 +++++++- 8 files changed, 45 insertions(+), 29 deletions(-) diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu index daecea9ebe..4ffffd4d2f 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu @@ -14,7 +14,7 @@ #include "gemm_a8w8_blockscale_manifest.h" using BlockwiseKernel = std::function; + torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, int)>; // Define a custom hash function for std::tuple struct IntTupleHash @@ -103,11 +103,11 @@ torch::Tensor gemm_a8w8_blockscale(torch::Tensor& XQ, if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) { - blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y); + blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, 1); } else if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16) { - blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y); + blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, 1); } else { diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu index 1a809a6bf4..0a2e4162af 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu @@ -14,7 +14,7 @@ #include "gemm_a8w8_blockscale_cktile_manifest.h" using BlockwiseKernel = std::function; + torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, bool, int)>; // Define a custom hash function for std::tuple struct IntTupleHash @@ -104,12 +104,12 @@ torch::Tensor gemm_a8w8_blockscale_cktile(torch::Tensor& XQ, if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) { blockscale_dispatch(M, N, K)( - XQ, WQ, x_scale, w_scale, Y, preshuffleB); + XQ, WQ, x_scale, w_scale, Y, preshuffleB, 1); } else if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16) { blockscale_dispatch(M, N, K)( - XQ, WQ, x_scale, w_scale, Y, preshuffleB); + XQ, WQ, x_scale, w_scale, Y, preshuffleB, 1); } else { diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu index b1c9077d82..48e183809c 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu @@ -12,7 +12,7 @@ #include "gemm_a8w8_blockscale_cktile_manifest.h" using BlockwiseKernel = std::function; + torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, bool, int)>; // For certain high priority shapes, we directly use the best kernel rather // than use heuristics. @@ -73,12 +73,12 @@ torch::Tensor gemm_a8w8_blockscale_cktile_tune(torch::Tensor& XQ, if(Y.dtype() == at::ScalarType::BFloat16) { blockwise_dispatch_cktile(kernelId)( - XQ, WQ, x_scale, w_scale, Y, preshuffleB); + XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch); } else if(Y.dtype() == at::ScalarType::Half) { blockwise_dispatch_cktile(kernelId)( - XQ, WQ, x_scale, w_scale, Y, preshuffleB); + XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch); } else { diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu index c6620ab89b..0fea1c4a46 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu @@ -12,7 +12,7 @@ #include "gemm_a8w8_blockscale_manifest.h" using BlockwiseKernel = std::function; + torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, int)>; // For certain high priority shapes, we directly use the best kernel rather // than use heuristics. @@ -71,11 +71,11 @@ torch::Tensor gemm_a8w8_blockscale_tune(torch::Tensor& XQ, if(Y.dtype() == at::ScalarType::BFloat16) { - blockwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y); + blockwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y, KBatch); } else if(Y.dtype() == at::ScalarType::Half) { - blockwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y); + blockwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y, KBatch); } else { diff --git a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py index 134d1313bd..ab2518f84c 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py +++ b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py @@ -123,7 +123,8 @@ def gen_ck_instance(self, k: KernelInstance): torch::Tensor &WQ, torch::Tensor &x_scale, torch::Tensor &w_scale, - torch::Tensor &Y + torch::Tensor &Y, + int KBatch ) {{ // Get M, N, K from input tensors. @@ -194,7 +195,7 @@ def gen_ck_instance(self, k: KernelInstance): ck::tensor_operation::device::GemmSpecialization::{{GemmSpec}}>; // Run kernel instance. - return gemm_a8w8_blockscale_impl(XQ, WQ, x_scale, w_scale, Y); + return gemm_a8w8_blockscale_impl(XQ, WQ, x_scale, w_scale, Y, KBatch); """ INSTANCE_IMPL_str = ( LEGACY_INSTANCE_IMPL.replace( @@ -246,7 +247,8 @@ def gen_ck_instance(self, k: KernelInstance): torch::Tensor &WQ, torch::Tensor &x_scale, torch::Tensor &w_scale, - torch::Tensor &Y + torch::Tensor &Y, + int KBatch ); """ @@ -330,7 +332,8 @@ def gen_manifest_head(self, kernels_dict): torch::Tensor &WQ, torch::Tensor &x_scale, torch::Tensor &w_scale, - torch::Tensor &Y); + torch::Tensor &Y, + int KBatch); """ MAINFEST_end = """ diff --git a/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py b/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py index f54046a677..a13f593c5c 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py +++ b/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py @@ -99,7 +99,8 @@ def gen_cktile_instance(self, k: TileKernelInstance): torch::Tensor &x_scale, torch::Tensor &w_scale, torch::Tensor &Y, - bool preshuffleB + bool preshuffleB, + int k_batch ) {{ // Get M, N, K from input tensors. @@ -124,7 +125,7 @@ def gen_cktile_instance(self, k: TileKernelInstance): {k.BlockPerCu}>; // Run kernel instance. - return gemm_a8w8_blockscale_cktile_impl(XQ, WQ, x_scale, w_scale, Y, preshuffleB); + return gemm_a8w8_blockscale_cktile_impl(XQ, WQ, x_scale, w_scale, Y, preshuffleB, k_batch); """ TILE_INSTANCE_IMPL_str = TILE_INSTANCE_IMPL.replace( @@ -147,7 +148,8 @@ def gen_cktile_instance(self, k: TileKernelInstance): torch::Tensor &x_scale, torch::Tensor &w_scale, torch::Tensor &Y, - bool preshuffleB + bool preshuffleB, + int k_batch ); """ @@ -232,7 +234,8 @@ def gen_manifest_head(self, kernels_dict): torch::Tensor &x_scale, torch::Tensor &w_scale, torch::Tensor &Y, - bool preshuffleB); + bool preshuffleB, + int k_batch); """ MAINFEST_end = """ diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh index 4aa313125c..2ba1ffc0d5 100644 --- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh +++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh @@ -226,11 +226,6 @@ void TileGemmComputeImpl(ck_tile::QuantGemmHostArgs& args) const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); const dim3 blocks = Kernel::BlockSize(); - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - if(!Kernel::IsSupportedArgument(kargs)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); @@ -275,7 +270,8 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ torch::Tensor& x_scale, torch::Tensor& w_scale, torch::Tensor& Y, - bool PreshuffleB) + bool PreshuffleB, + int k_batch = 1) { // check TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!"); @@ -329,8 +325,7 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ args.bq_ptr = w_scale.data_ptr(); args.c_ptr = Y.data_ptr(); - // split-k is not supported yet for tile quant gemm, set k_batch to 1 - args.k_batch = 1; + args.k_batch = k_batch; args.M = M; args.N = N; args.K = K; @@ -358,6 +353,15 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ args.stride_AQ = stride_AQ; args.stride_BQ = stride_BQ; + // Split-K uses atomic_add into C; zero the output buffer first. + if(k_batch > 1) + { + hipMemsetAsync(Y.data_ptr(), + 0, + M * N * sizeof(OutDataType), + at::hip::getCurrentHIPStream()); + } + // do tile GEMM if(PreshuffleB) { diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh index 84919bd65d..f81247d206 100644 --- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh +++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh @@ -124,7 +124,8 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_impl(torch::Tensor& XQ, torch::Tensor& WQ, torch::Tensor& x_scale, torch::Tensor& w_scale, - torch::Tensor& Y) + torch::Tensor& Y, + int KBatch = 1) { int M = XQ.size(0); int N = WQ.size(0); @@ -160,6 +161,11 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_impl(torch::Tensor& XQ, b_element_op, cde_element_op); + if(KBatch > 1) + { + device_gemm.SetKBatch(&argument, KBatch); + } + TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()}); From ec422c3e56eac696af66419319d6897ebb666b16 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 30 Mar 2026 08:39:40 -0400 Subject: [PATCH 2/5] Wire splitK from tuning CSV through production blockscale GEMM dispatch The tuning infrastructure already sweeps splitK and writes it to the CSV, but the production dispatch ignored it and hardcoded KBatch=1. Add splitK as a runtime parameter to the non-tune entry points so tuned split-K values are used without compiling the full _tune instance set. Made-with: Cursor --- aiter/ops/gemm_op_a8w8.py | 11 +++++++++-- csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu | 8 +++++--- .../gemm_a8w8_blockscale_cktile.cu | 10 ++++++---- .../include/gemm_a8w8_blockscale.h | 3 ++- .../include/gemm_a8w8_blockscale_cktile.h | 3 ++- csrc/include/rocm_ops.hpp | 6 ++++-- 6 files changed, 28 insertions(+), 13 deletions(-) diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index 898d3f47d2..424006a172 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -162,6 +162,7 @@ def gemm_a8w8_blockscale_ck( x_scale: torch.Tensor, w_scale: torch.Tensor, Out: torch.Tensor, + splitK: int = 0, ) -> torch.Tensor: ... @@ -177,6 +178,7 @@ def gemm_a8w8_blockscale_cktile( w_scale: torch.Tensor, Out: torch.Tensor, isBpreshuffled: bool = False, + splitK: int = 0, ) -> torch.Tensor: ... @@ -610,10 +612,15 @@ def gemm_a8w8_blockscale( ) if config is not None: libtype = config["libtype"] + splitK = int(config.get("splitK", 0)) if libtype == "ck": - return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y) + return gemm_a8w8_blockscale_ck( + XQ, WQ, x_scale, w_scale, Y, splitK=splitK + ) elif libtype == "cktile": - return gemm_a8w8_blockscale_cktile(XQ, WQ, x_scale, w_scale, Y) + return gemm_a8w8_blockscale_cktile( + XQ, WQ, x_scale, w_scale, Y, splitK=splitK + ) else: assert 0, f"Unsupported libtype {libtype} for gemm_a8w8_blockscale" return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y) diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu index 4ffffd4d2f..61c54355a1 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu @@ -92,7 +92,8 @@ torch::Tensor gemm_a8w8_blockscale(torch::Tensor& XQ, torch::Tensor& WQ, torch::Tensor& x_scale, torch::Tensor& w_scale, - torch::Tensor& Y) + torch::Tensor& Y, + int splitK) { TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!"); TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!"); @@ -100,14 +101,15 @@ torch::Tensor gemm_a8w8_blockscale(torch::Tensor& XQ, int M = XQ.size(0); int N = WQ.size(0); int K = XQ.size(1); + int KBatch = static_cast(std::pow(2, splitK)); if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) { - blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, 1); + blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, KBatch); } else if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16) { - blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, 1); + blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, KBatch); } else { diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu index 0a2e4162af..37600de72a 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu @@ -92,7 +92,8 @@ torch::Tensor gemm_a8w8_blockscale_cktile(torch::Tensor& XQ, torch::Tensor& x_scale, torch::Tensor& w_scale, torch::Tensor& Y, - bool preshuffleB) + bool preshuffleB, + int splitK) { TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!"); TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!"); @@ -100,16 +101,17 @@ torch::Tensor gemm_a8w8_blockscale_cktile(torch::Tensor& XQ, int M = XQ.size(0); int N = WQ.size(0); int K = XQ.size(1); + int KBatch = static_cast(std::pow(2, splitK)); if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) { blockscale_dispatch(M, N, K)( - XQ, WQ, x_scale, w_scale, Y, preshuffleB, 1); + XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch); } else if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16) { blockscale_dispatch(M, N, K)( - XQ, WQ, x_scale, w_scale, Y, preshuffleB, 1); + XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch); } else { @@ -125,5 +127,5 @@ torch::Tensor gemm_a8w8_blockscale_bpreshuffle_cktile(torch::Tensor& XQ, torch::Tensor& Y, bool preshuffleB) { - return gemm_a8w8_blockscale_cktile(XQ, WQ, x_scale, w_scale, Y, preshuffleB); + return gemm_a8w8_blockscale_cktile(XQ, WQ, x_scale, w_scale, Y, preshuffleB, 0); } diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h index fa909c12f5..2aa90db358 100644 --- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h +++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h @@ -9,7 +9,8 @@ torch::Tensor gemm_a8w8_blockscale(torch::Tensor& XQ, torch::Tensor& WQ, torch::Tensor& x_scale, torch::Tensor& w_scale, - torch::Tensor& Y); + torch::Tensor& Y, + int splitK = 0); torch::Tensor gemm_a8w8_blockscale_tune(torch::Tensor& XQ, torch::Tensor& WQ, diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h index 84455c88c6..a5e8475945 100644 --- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h +++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h @@ -10,7 +10,8 @@ torch::Tensor gemm_a8w8_blockscale_cktile(torch::Tensor& XQ, torch::Tensor& x_scale, torch::Tensor& w_scale, torch::Tensor& Y, - bool preshuffleB); + bool preshuffleB, + int splitK = 0); torch::Tensor gemm_a8w8_blockscale_cktile_tune(torch::Tensor& XQ, torch::Tensor& WQ, diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 09efe6fdd2..c41217d2a6 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -471,7 +471,8 @@ namespace py = pybind11; py::arg("WQ"), \ py::arg("x_scale"), \ py::arg("w_scale"), \ - py::arg("Out")); + py::arg("Out"), \ + py::arg("splitK") = 0); #define GEMM_A8W8_BLOCKSCALE_TUNE_PYBIND \ m.def("gemm_a8w8_blockscale_tune", \ @@ -494,7 +495,8 @@ namespace py = pybind11; py::arg("x_scale"), \ py::arg("w_scale"), \ py::arg("Out"), \ - py::arg("preshuffleB") = false); + py::arg("preshuffleB") = false, \ + py::arg("splitK") = 0); #define GEMM_A8W8_BLOCKSCALE_CKTILE_TUNE_PYBIND \ m.def("gemm_a8w8_blockscale_cktile_tune", \ From 2bf04dddac64fa5aeeb25cf71b432f485c4535ef Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Apr 2026 15:04:04 +0000 Subject: [PATCH 3/5] Address PR review feedback: validate splitK, fix hipMemset stride issue, add correctness test Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2 Co-authored-by: samremes <181322991+samremes@users.noreply.github.com> --- .../gemm_a8w8_blockscale.cu | 13 ++++-- .../gemm_a8w8_blockscale_cktile.cu | 13 ++++-- .../gemm_a8w8_blockscale_cktile_common.cuh | 7 ++- .../include/gemm_a8w8_blockscale_common.cuh | 2 + op_tests/test_gemm_a8w8_blockscale.py | 46 +++++++++++++++++++ 5 files changed, 67 insertions(+), 14 deletions(-) diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu index 61c54355a1..ffc3695575 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu @@ -1,7 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include #include #include @@ -98,10 +97,14 @@ torch::Tensor gemm_a8w8_blockscale(torch::Tensor& XQ, TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!"); TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!"); - int M = XQ.size(0); - int N = WQ.size(0); - int K = XQ.size(1); - int KBatch = static_cast(std::pow(2, splitK)); + TORCH_CHECK(splitK >= 0 && splitK <= 30, + "splitK must be in the range [0, 30], got ", + splitK); + + int M = XQ.size(0); + int N = WQ.size(0); + int K = XQ.size(1); + int KBatch = 1 << splitK; if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) { diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu index 37600de72a..91e1e7e0e5 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu @@ -1,7 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -#include #include #include @@ -98,10 +97,14 @@ torch::Tensor gemm_a8w8_blockscale_cktile(torch::Tensor& XQ, TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!"); TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!"); - int M = XQ.size(0); - int N = WQ.size(0); - int K = XQ.size(1); - int KBatch = static_cast(std::pow(2, splitK)); + TORCH_CHECK(splitK >= 0 && splitK <= 30, + "splitK must be in the range [0, 30], got ", + splitK); + + int M = XQ.size(0); + int N = WQ.size(0); + int K = XQ.size(1); + int KBatch = 1 << splitK; if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) { diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh index 07146f322d..99657fb56a 100644 --- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh +++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh @@ -364,12 +364,11 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ args.stride_BQ = stride_BQ; // Split-K uses atomic_add into C; zero the output buffer first. + // Use zero_() so all rows are cleared regardless of the leading-dimension + // stride (e.g. padded tensors produced by vLLM's _maybe_pad_fp8_weight). if(k_batch > 1) { - hipMemsetAsync(Y.data_ptr(), - 0, - M * N * sizeof(OutDataType), - at::hip::getCurrentHIPStream()); + Y.zero_(); } // do tile GEMM diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh index f81247d206..d7f0a43d18 100644 --- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh +++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh @@ -161,6 +161,8 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_impl(torch::Tensor& XQ, b_element_op, cde_element_op); + TORCH_CHECK(KBatch >= 1, "KBatch must be >= 1, got ", KBatch); + if(KBatch > 1) { device_gemm.SetKBatch(&argument, KBatch); diff --git a/op_tests/test_gemm_a8w8_blockscale.py b/op_tests/test_gemm_a8w8_blockscale.py index 6604bdb181..76b6e5ddcb 100755 --- a/op_tests/test_gemm_a8w8_blockscale.py +++ b/op_tests/test_gemm_a8w8_blockscale.py @@ -13,6 +13,7 @@ import torch import torch.nn.functional as F from aiter import dtypes +from aiter.ops.gemm_op_a8w8 import gemm_a8w8_blockscale_ck, gemm_a8w8_blockscale_cktile from aiter.ops.shuffle import shuffle_weight from aiter.test_common import benchmark, checkAllclose, perftest from einops import rearrange @@ -126,6 +127,44 @@ def run_asm(x, weight, x_scale, w_scale, dtype=dtypes.bf16, kernel_name=None): return aiter.gemm_a8w8_blockscale_bpreshuffle_asm(x, weight, out, x_scale, w_scale) +def test_splitk_correctness(m=4, n=2112, k=7168, dtype=dtypes.bf16, splitK=1): + """Verify that splitK > 0 produces the same output as splitK=0 (within fp tolerance). + + split-K accumulates partial tiles via atomic_add, which changes the floating-point + reduction order. We therefore use a relaxed tolerance that matches the cumulative + rounding error introduced by K-splitting. + """ + block_shape_n, block_shape_k = block_shape + scale_n = (n + block_shape_n - 1) // block_shape_n + scale_k = (k + block_shape_k - 1) // block_shape_k + + x = (torch.rand((m, k), dtype=dtypes.fp32, device="cuda") / 10).to(dtypes.fp8) + weight = (torch.rand((n, k), dtype=dtypes.fp32, device="cuda") / 10).to(dtypes.fp8) + x_scale = torch.rand([m, scale_k], dtype=dtypes.fp32, device="cuda") + w_scale = torch.rand([scale_n, scale_k], dtype=dtypes.fp32, device="cuda") + + # CK path (no preshuffle): compare splitK=0 vs splitK>0 + Y_base = torch.empty((m, n), dtype=dtype, device="cuda") + Y_split = torch.empty((m, n), dtype=dtype, device="cuda") + gemm_a8w8_blockscale_ck(x, weight, x_scale, w_scale, Y_base, splitK=0) + gemm_a8w8_blockscale_ck(x, weight, x_scale, w_scale, Y_split, splitK=splitK) + ck_err = checkAllclose(Y_base, Y_split, msg=f"ck splitK={splitK} vs splitK=0", rtol=1e-2, atol=1e-2) + + # CKTile path (no preshuffle): compare splitK=0 vs splitK>0 + Y_base_tile = torch.empty((m, n), dtype=dtype, device="cuda") + Y_split_tile = torch.empty((m, n), dtype=dtype, device="cuda") + gemm_a8w8_blockscale_cktile(x, weight, x_scale, w_scale, Y_base_tile, False, splitK=0) + gemm_a8w8_blockscale_cktile(x, weight, x_scale, w_scale, Y_split_tile, False, splitK=splitK) + cktile_err = checkAllclose( + Y_base_tile, Y_split_tile, msg=f"cktile splitK={splitK} vs splitK=0", rtol=1e-2, atol=1e-2 + ) + + print( + f"test_splitk_correctness(m={m}, n={n}, k={k}, splitK={splitK}): " + f"ck_err={ck_err:.4g}, cktile_err={cktile_err:.4g}" + ) + + parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, description="config input of test", @@ -256,3 +295,10 @@ def run_asm(x, weight, x_scale, w_scale, dtype=dtypes.bf16, kernel_name=None): df_md = df.to_markdown(index=False) aiter.logger.info("gemm_a8w8_blockscale summary (markdown):\n%s", df_md) + +# Correctness check: verify split-K produces matching results +print("\nRunning split-K correctness checks ...") +for splitK in [1, 2]: + test_splitk_correctness(m=4, n=2112, k=7168, splitK=splitK) + test_splitk_correctness(m=1, n=3072, k=1536, splitK=splitK) +print("split-K correctness checks passed.") From 7e319e58a9bc9d39d5c6b1ecd587944bde3cba6a Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Thu, 9 Apr 2026 13:49:50 -0500 Subject: [PATCH 4/5] black format --- op_tests/test_gemm_a8w8_blockscale.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/op_tests/test_gemm_a8w8_blockscale.py b/op_tests/test_gemm_a8w8_blockscale.py index 2b2adbb254..c6581693bc 100755 --- a/op_tests/test_gemm_a8w8_blockscale.py +++ b/op_tests/test_gemm_a8w8_blockscale.py @@ -148,15 +148,25 @@ def test_splitk_correctness(m=4, n=2112, k=7168, dtype=dtypes.bf16, splitK=1): Y_split = torch.empty((m, n), dtype=dtype, device="cuda") gemm_a8w8_blockscale_ck(x, weight, x_scale, w_scale, Y_base, splitK=0) gemm_a8w8_blockscale_ck(x, weight, x_scale, w_scale, Y_split, splitK=splitK) - ck_err = checkAllclose(Y_base, Y_split, msg=f"ck splitK={splitK} vs splitK=0", rtol=1e-2, atol=1e-2) + ck_err = checkAllclose( + Y_base, Y_split, msg=f"ck splitK={splitK} vs splitK=0", rtol=1e-2, atol=1e-2 + ) # CKTile path (no preshuffle): compare splitK=0 vs splitK>0 Y_base_tile = torch.empty((m, n), dtype=dtype, device="cuda") Y_split_tile = torch.empty((m, n), dtype=dtype, device="cuda") - gemm_a8w8_blockscale_cktile(x, weight, x_scale, w_scale, Y_base_tile, False, splitK=0) - gemm_a8w8_blockscale_cktile(x, weight, x_scale, w_scale, Y_split_tile, False, splitK=splitK) + gemm_a8w8_blockscale_cktile( + x, weight, x_scale, w_scale, Y_base_tile, False, splitK=0 + ) + gemm_a8w8_blockscale_cktile( + x, weight, x_scale, w_scale, Y_split_tile, False, splitK=splitK + ) cktile_err = checkAllclose( - Y_base_tile, Y_split_tile, msg=f"cktile splitK={splitK} vs splitK=0", rtol=1e-2, atol=1e-2 + Y_base_tile, + Y_split_tile, + msg=f"cktile splitK={splitK} vs splitK=0", + rtol=1e-2, + atol=1e-2, ) print( From 118099e39b7a46f345ef3a44f024e05e140aef3e Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Fri, 10 Apr 2026 05:33:15 -0500 Subject: [PATCH 5/5] fix splitk test dimensions --- op_tests/test_gemm_a8w8_blockscale.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/op_tests/test_gemm_a8w8_blockscale.py b/op_tests/test_gemm_a8w8_blockscale.py index c6581693bc..6c1358e729 100755 --- a/op_tests/test_gemm_a8w8_blockscale.py +++ b/op_tests/test_gemm_a8w8_blockscale.py @@ -309,6 +309,4 @@ def test_splitk_correctness(m=4, n=2112, k=7168, dtype=dtypes.bf16, splitK=1): # Correctness check: verify split-K produces matching results print("\nRunning split-K correctness checks ...") for splitK in [1, 2]: - test_splitk_correctness(m=4, n=2112, k=7168, splitK=splitK) - test_splitk_correctness(m=1, n=3072, k=1536, splitK=splitK) -print("split-K correctness checks passed.") + test_splitk_correctness(m=4, n=512, k=16384, splitK=splitK)