diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py
index 1cea93bfeb..f4b9d9b305 100644
--- a/aiter/ops/gemm_op_a8w8.py
+++ b/aiter/ops/gemm_op_a8w8.py
@@ -216,6 +216,7 @@ def gemm_a8w8_blockscale_ck(
     x_scale: torch.Tensor,
     w_scale: torch.Tensor,
     Out: torch.Tensor,
+    splitK: int = 0,
 ) -> torch.Tensor: ...
@@ -231,6 +232,7 @@ def gemm_a8w8_blockscale_cktile(
     w_scale: torch.Tensor,
     Out: torch.Tensor,
     isBpreshuffled: bool = False,
+    splitK: int = 0,
 ) -> torch.Tensor: ...
@@ -685,10 +687,15 @@ def gemm_a8w8_blockscale(
     )
     if config is not None:
         libtype = config["libtype"]
+        splitK = int(config.get("splitK", 0))
         if libtype == "ck":
-            return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y)
+            return gemm_a8w8_blockscale_ck(
+                XQ, WQ, x_scale, w_scale, Y, splitK=splitK
+            )
         elif libtype == "cktile":
-            return gemm_a8w8_blockscale_cktile(XQ, WQ, x_scale, w_scale, Y)
+            return gemm_a8w8_blockscale_cktile(
+                XQ, WQ, x_scale, w_scale, Y, splitK=splitK
+            )
         else:
             assert 0, f"Unsupported libtype {libtype} for gemm_a8w8_blockscale"
     try:
diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu
index 6d99612be2..5449d5b7ee 100644
--- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu
+++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu
@@ -1,7 +1,6 @@
 // SPDX-License-Identifier: MIT
 // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

-#include
 #include
 #include
@@ -15,7 +14,7 @@
 #include "gemm_a8w8_blockscale_manifest.h"

 using BlockwiseKernel = std::function<torch::Tensor(
-    torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&)>;
+    torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, int)>;

 using BlockwiseKernelMap = GemmDispatchMap;
@@ -83,22 +82,28 @@ torch::Tensor gemm_a8w8_blockscale(torch::Tensor& XQ,
                                    torch::Tensor& WQ,
                                    torch::Tensor& x_scale,
                                    torch::Tensor& w_scale,
-                                   torch::Tensor& Y)
+                                   torch::Tensor& Y,
+                                   int splitK)
 {
     TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!");
     TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!");

-    int M = XQ.size(0);
-    int N = WQ.size(0);
-    int K = XQ.size(1);
+    TORCH_CHECK(splitK >= 0 && splitK <= 30,
+                "splitK must be in the range [0, 30], got ",
+                splitK);
+
+    int M      = XQ.size(0);
+    int N      = WQ.size(0);
+    int K      = XQ.size(1);
+    int KBatch = 1 << splitK;

     if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half)
     {
-        blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y);
+        blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, KBatch);
     }
     else if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16)
     {
-        blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y);
+        blockscale_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, KBatch);
     }
     else
     {
diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu
index d5cdf0d239..b6a9f3ca73 100644
--- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu
+++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu
@@ -1,7 +1,6 @@
 // SPDX-License-Identifier: MIT
 // Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

-#include
 #include
 #include
@@ -15,7 +14,7 @@
 #include "gemm_a8w8_blockscale_cktile_manifest.h"

 using BlockwiseKernel = std::function<torch::Tensor(
-    torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, bool)>;
+    torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, bool, int)>;

 using BlockwiseKernelMap = GemmDispatchMap;
@@ -83,24 +82,30 @@ torch::Tensor gemm_a8w8_blockscale_cktile(torch::Tensor& XQ,
                                           torch::Tensor& x_scale,
                                           torch::Tensor& w_scale,
                                           torch::Tensor& Y,
-                                          bool preshuffleB)
+                                          bool preshuffleB,
+                                          int splitK)
 {
     TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!");
     TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!");

-    int M = XQ.size(0);
-    int N = WQ.size(0);
-    int K = XQ.size(1);
+    TORCH_CHECK(splitK >= 0 && splitK <= 30,
+                "splitK must be in the range [0, 30], got ",
+                splitK);
+
+    int M      = XQ.size(0);
+    int N      = WQ.size(0);
+    int K      = XQ.size(1);
+    int KBatch = 1 << splitK;

     if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half)
     {
         blockscale_dispatch(M, N, K)(
-            XQ, WQ, x_scale, w_scale, Y, preshuffleB);
+            XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch);
     }
     else if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16)
     {
         blockscale_dispatch(M, N, K)(
-            XQ, WQ, x_scale, w_scale, Y, preshuffleB);
+            XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch);
     }
     else
     {
@@ -116,5 +121,5 @@ torch::Tensor gemm_a8w8_blockscale_bpreshuffle_cktile(torch::Tensor& XQ,
                                                       torch::Tensor& Y,
                                                       bool preshuffleB)
 {
-    return gemm_a8w8_blockscale_cktile(XQ, WQ, x_scale, w_scale, Y, preshuffleB);
+    return gemm_a8w8_blockscale_cktile(XQ, WQ, x_scale, w_scale, Y, preshuffleB, 0);
 }
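Reviewer note (not part of the patch): in all three entry points above, `splitK` is a log2 exponent rather than the partition count itself — the dispatchers compute `KBatch = 1 << splitK`, so the default `splitK=0` keeps today's single-pass behavior and the `[0, 30]` guard keeps the shift well-defined for a 32-bit `int`. A minimal Python sketch of that mapping, with a hypothetical helper name:

```python
def k_batch_from_splitk(splitK: int) -> int:
    """Hypothetical mirror of the C++ `KBatch = 1 << splitK` computation."""
    if not 0 <= splitK <= 30:
        raise ValueError(f"splitK must be in [0, 30], got {splitK}")
    return 1 << splitK  # splitK=0 -> 1 partition (no split), splitK=3 -> 8

assert k_batch_from_splitk(0) == 1
assert k_batch_from_splitk(2) == 4
```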
diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu
index b1c9077d82..48e183809c 100644
--- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu
+++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu
@@ -12,7 +12,7 @@
 #include "gemm_a8w8_blockscale_cktile_manifest.h"

 using BlockwiseKernel = std::function<torch::Tensor(
-    torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, bool)>;
+    torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, bool, int)>;

 // For certain high priority shapes, we directly use the best kernel rather
 // than use heuristics.
@@ -73,12 +73,12 @@ torch::Tensor gemm_a8w8_blockscale_cktile_tune(torch::Tensor& XQ,
     if(Y.dtype() == at::ScalarType::BFloat16)
     {
         blockwise_dispatch_cktile(kernelId)(
-            XQ, WQ, x_scale, w_scale, Y, preshuffleB);
+            XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch);
     }
     else if(Y.dtype() == at::ScalarType::Half)
     {
         blockwise_dispatch_cktile(kernelId)(
-            XQ, WQ, x_scale, w_scale, Y, preshuffleB);
+            XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch);
     }
     else
     {
diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu
index c6620ab89b..0fea1c4a46 100644
--- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu
+++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu
@@ -12,7 +12,7 @@
 #include "gemm_a8w8_blockscale_manifest.h"

 using BlockwiseKernel = std::function<torch::Tensor(
-    torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&)>;
+    torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, int)>;

 // For certain high priority shapes, we directly use the best kernel rather
 // than use heuristics.
@@ -71,11 +71,11 @@ torch::Tensor gemm_a8w8_blockscale_tune(torch::Tensor& XQ,
     if(Y.dtype() == at::ScalarType::BFloat16)
     {
-        blockwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y);
+        blockwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y, KBatch);
     }
     else if(Y.dtype() == at::ScalarType::Half)
     {
-        blockwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y);
+        blockwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y, KBatch);
     }
     else
     {
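Reviewer note: both tune entry points now forward a `KBatch` to the dispatched kernel, so an offline tuner can sweep split-K alongside kernel IDs. A hedged sketch of what such a sweep could look like from Python — `gemm` stands in for either tuned entry point, and the timing helper is illustrative, not part of this patch:

```python
import torch

def time_gpu_ms(fn, iters=10):
    """Crude event-based GPU timer (illustrative only)."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()  # warm-up launch
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def sweep_splitk(gemm, XQ, WQ, x_scale, w_scale, Y, candidates=(0, 1, 2, 3)):
    # Return the fastest splitK exponent among the candidates.
    timings = {
        s: time_gpu_ms(lambda s=s: gemm(XQ, WQ, x_scale, w_scale, Y, splitK=s))
        for s in candidates
    }
    return min(timings, key=timings.get)
```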
diff --git a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py
index c538659f96..fd6f799b3b 100644
--- a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py
+++ b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py
@@ -115,7 +115,8 @@ def gen_ck_instance(self, k: KernelInstance):
         torch::Tensor &WQ,
         torch::Tensor &x_scale,
         torch::Tensor &w_scale,
-        torch::Tensor &Y
+        torch::Tensor &Y,
+        int KBatch
     )
     {{
         // Get M, N, K from input tensors.
@@ -186,7 +187,7 @@ def gen_ck_instance(self, k: KernelInstance):
             ck::tensor_operation::device::GemmSpecialization::{{GemmSpec}}>;

         // Run kernel instance.
-        return gemm_a8w8_blockscale_impl(XQ, WQ, x_scale, w_scale, Y);
+        return gemm_a8w8_blockscale_impl(XQ, WQ, x_scale, w_scale, Y, KBatch);
 """
     INSTANCE_IMPL_str = (
         LEGACY_INSTANCE_IMPL.replace(
@@ -238,7 +239,8 @@ def gen_ck_instance(self, k: KernelInstance):
         torch::Tensor &WQ,
         torch::Tensor &x_scale,
         torch::Tensor &w_scale,
-        torch::Tensor &Y
+        torch::Tensor &Y,
+        int KBatch
     );
 """
@@ -312,7 +314,8 @@ def gen_manifest_head(self, kernels_dict):
         torch::Tensor &WQ,
         torch::Tensor &x_scale,
         torch::Tensor &w_scale,
-        torch::Tensor &Y);
+        torch::Tensor &Y,
+        int KBatch);
 """

 MAINFEST_end = """
diff --git a/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py b/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py
index 4cd9bc02a9..2326dea99e 100644
--- a/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py
+++ b/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py
@@ -74,7 +74,8 @@ def gen_cktile_instance(self, k: TileKernelInstance):
         torch::Tensor &x_scale,
         torch::Tensor &w_scale,
         torch::Tensor &Y,
-        bool preshuffleB
+        bool preshuffleB,
+        int k_batch
     )
     {{
         // Get M, N, K from input tensors.
@@ -100,7 +101,7 @@ def gen_cktile_instance(self, k: TileKernelInstance):
             {str(k.AQRowMajor).lower()}>;

         // Run kernel instance.
-        return gemm_a8w8_blockscale_cktile_impl(XQ, WQ, x_scale, w_scale, Y, preshuffleB);
+        return gemm_a8w8_blockscale_cktile_impl(XQ, WQ, x_scale, w_scale, Y, preshuffleB, k_batch);
 """
     TILE_INSTANCE_IMPL_str = TILE_INSTANCE_IMPL.replace(
@@ -123,7 +124,8 @@ def gen_cktile_instance(self, k: TileKernelInstance):
         torch::Tensor &x_scale,
         torch::Tensor &w_scale,
         torch::Tensor &Y,
-        bool preshuffleB
+        bool preshuffleB,
+        int k_batch
     );
 """
@@ -198,7 +200,8 @@ def gen_manifest_head(self, kernels_dict):
         torch::Tensor &x_scale,
         torch::Tensor &w_scale,
         torch::Tensor &Y,
-        bool preshuffleB);
+        bool preshuffleB,
+        int k_batch);
 """

 MAINFEST_end = """
diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h
index fa909c12f5..2aa90db358 100644
--- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h
+++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h
@@ -9,7 +9,8 @@ torch::Tensor gemm_a8w8_blockscale(torch::Tensor& XQ,
                                    torch::Tensor& WQ,
                                    torch::Tensor& x_scale,
                                    torch::Tensor& w_scale,
-                                   torch::Tensor& Y);
+                                   torch::Tensor& Y,
+                                   int splitK = 0);

 torch::Tensor gemm_a8w8_blockscale_tune(torch::Tensor& XQ,
                                         torch::Tensor& WQ,
diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h
index 84455c88c6..a5e8475945 100644
--- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h
+++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h
@@ -10,7 +10,8 @@ torch::Tensor gemm_a8w8_blockscale_cktile(torch::Tensor& XQ,
                                           torch::Tensor& x_scale,
                                           torch::Tensor& w_scale,
                                           torch::Tensor& Y,
-                                          bool preshuffleB);
+                                          bool preshuffleB,
+                                          int splitK = 0);

 torch::Tensor gemm_a8w8_blockscale_cktile_tune(torch::Tensor& XQ,
                                                torch::Tensor& WQ,
diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh
index 773c08311a..595a810056 100644
--- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh
+++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh
@@ -234,11 +234,6 @@ void TileGemmComputeImpl(ck_tile::QuantGemmHostArgs& args)
     const dim3 grids  = Kernel::GridSize(args.M, args.N, args.k_batch);
     const dim3 blocks = Kernel::BlockSize();

-    if(args.k_batch != 1)
-    {
-        throw std::runtime_error("split-k is not supported yet!");
-    }
-
     if(!Kernel::IsSupportedArgument(kargs))
     {
         throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
@@ -283,7 +278,8 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ,
                                                                torch::Tensor& x_scale,
                                                                torch::Tensor& w_scale,
                                                                torch::Tensor& Y,
-                                                               bool PreshuffleB)
+                                                               bool PreshuffleB,
+                                                               int k_batch = 1)
 {
     // check
     TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!");
@@ -372,8 +368,7 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ,
     args.bq_ptr = w_scale.data_ptr();
     args.c_ptr  = Y.data_ptr();

-    // split-k is not supported yet for tile quant gemm, set k_batch to 1
-    args.k_batch = 1;
+    args.k_batch = k_batch;
     args.M = M;
     args.N = N;
     args.K = K;
@@ -400,6 +395,14 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ,
     args.stride_AQ = stride_AQ;
     args.stride_BQ = stride_BQ;

+    // Split-K uses atomic_add into C; zero the output buffer first.
+    // Use zero_() so all rows are cleared regardless of the leading-dimension
+    // stride (e.g. padded tensors produced by vLLM's _maybe_pad_fp8_weight).
+    if(k_batch > 1)
+    {
+        Y.zero_();
+    }
+
     // do tile GEMM
     if(PreshuffleB)
     {
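Reviewer note: the `Y.zero_()` added above is the load-bearing part of enabling split-K in the cktile path — per the patch's own comment, each K-partition atomically adds its partial product into C, so C must start at zero, and `zero_()` clears every row even when the tensor is stride-padded. A host-side Python emulation of that contract (names and shapes illustrative):

```python
import torch

def splitk_reference(A, B, k_batch):
    """Emulate split-K: partition K, accumulate partial GEMMs into a zeroed output."""
    M, K = A.shape
    N = B.shape[0]  # B is (N, K), matching the WQ layout in this PR
    Y = torch.zeros(M, N)  # must start at zero, like Y.zero_() before the kernel
    chunk = (K + k_batch - 1) // k_batch
    for i in range(k_batch):
        s, e = i * chunk, min((i + 1) * chunk, K)
        Y += A[:, s:e] @ B[:, s:e].t()  # one partial product per K-partition
    return Y

A, B = torch.randn(4, 64), torch.randn(8, 64)
assert torch.allclose(splitk_reference(A, B, 4), A @ B.t(), atol=1e-4)
```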
diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh
index 84919bd65d..d7f0a43d18 100644
--- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh
+++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh
@@ -124,7 +124,8 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_impl(torch::Tensor& XQ,
                                                         torch::Tensor& WQ,
                                                         torch::Tensor& x_scale,
                                                         torch::Tensor& w_scale,
-                                                        torch::Tensor& Y)
+                                                        torch::Tensor& Y,
+                                                        int KBatch = 1)
 {
     int M = XQ.size(0);
     int N = WQ.size(0);
@@ -160,6 +161,13 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_impl(torch::Tensor& XQ,
                                              b_element_op,
                                              cde_element_op);

+    TORCH_CHECK(KBatch >= 1, "KBatch must be >= 1, got ", KBatch);
+
+    if(KBatch > 1)
+    {
+        device_gemm.SetKBatch(&argument, KBatch);
+    }
+
     TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!");

     invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()});
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index d2a2bd654f..08b06441cd 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -497,7 +497,8 @@ namespace py = pybind11;
           py::arg("WQ"),                        \
           py::arg("x_scale"),                   \
           py::arg("w_scale"),                   \
-          py::arg("Out"));
+          py::arg("Out"),                       \
+          py::arg("splitK") = 0);

 #define GEMM_A8W8_BLOCKSCALE_TUNE_PYBIND        \
     m.def("gemm_a8w8_blockscale_tune",          \
@@ -520,7 +521,8 @@ namespace py = pybind11;
           py::arg("x_scale"),                   \
           py::arg("w_scale"),                   \
           py::arg("Out"),                       \
-          py::arg("preshuffleB") = false);
+          py::arg("preshuffleB") = false,       \
+          py::arg("splitK") = 0);

 #define GEMM_A8W8_BLOCKSCALE_CKTILE_TUNE_PYBIND \
     m.def("gemm_a8w8_blockscale_cktile_tune",   \
+ """ + block_shape_n, block_shape_k = block_shape + scale_n = (n + block_shape_n - 1) // block_shape_n + scale_k = (k + block_shape_k - 1) // block_shape_k + + x = (torch.rand((m, k), dtype=dtypes.fp32, device="cuda") / 10).to(dtypes.fp8) + weight = (torch.rand((n, k), dtype=dtypes.fp32, device="cuda") / 10).to(dtypes.fp8) + x_scale = torch.rand([m, scale_k], dtype=dtypes.fp32, device="cuda") + w_scale = torch.rand([scale_n, scale_k], dtype=dtypes.fp32, device="cuda") + + # CK path (no preshuffle): compare splitK=0 vs splitK>0 + Y_base = torch.empty((m, n), dtype=dtype, device="cuda") + Y_split = torch.empty((m, n), dtype=dtype, device="cuda") + gemm_a8w8_blockscale_ck(x, weight, x_scale, w_scale, Y_base, splitK=0) + gemm_a8w8_blockscale_ck(x, weight, x_scale, w_scale, Y_split, splitK=splitK) + ck_err = checkAllclose( + Y_base, Y_split, msg=f"ck splitK={splitK} vs splitK=0", rtol=1e-2, atol=1e-2 + ) + + # CKTile path (no preshuffle): compare splitK=0 vs splitK>0 + Y_base_tile = torch.empty((m, n), dtype=dtype, device="cuda") + Y_split_tile = torch.empty((m, n), dtype=dtype, device="cuda") + gemm_a8w8_blockscale_cktile( + x, weight, x_scale, w_scale, Y_base_tile, False, splitK=0 + ) + gemm_a8w8_blockscale_cktile( + x, weight, x_scale, w_scale, Y_split_tile, False, splitK=splitK + ) + cktile_err = checkAllclose( + Y_base_tile, + Y_split_tile, + msg=f"cktile splitK={splitK} vs splitK=0", + rtol=1e-2, + atol=1e-2, + ) + + print( + f"test_splitk_correctness(m={m}, n={n}, k={k}, splitK={splitK}): " + f"ck_err={ck_err:.4g}, cktile_err={cktile_err:.4g}" + ) + + parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, description="config input of test", @@ -301,6 +350,12 @@ def run_asm(x, weight, x_scale, w_scale, dtype=dtypes.bf16, kernel_name=None): df_md = df.to_markdown(index=False) aiter.logger.info("gemm_a8w8_blockscale summary (markdown):\n%s", df_md) +# Correctness check: verify split-K produces matching results +print("\nRunning split-K correctness checks ...") +for splitK in [1, 2]: + test_splitk_correctness(m=4, n=512, k=16384, splitK=splitK) + +# Save results from benchmarks if args.output: os.makedirs(args.output, exist_ok=True) if args.csv: