diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_instance.py b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_instance.py index 7969c33532..56a7b4c12a 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_instance.py +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_instance.py @@ -38,47 +38,56 @@ class TileKernelInstance: BlockPerCu: int # 1..BLOCK_PER_CU_MAX + # When True, 8-warp kernels read x_scale in row-major layout natively, + # skipping the host-side transpose; `name` then gains an "aqrm" suffix. + AQRowMajor: bool = False + + @property + def is_eight_warp(self) -> bool: + return self.M_Warp * self.N_Warp * self.K_Warp == 8 and self.K_Warp_Tile == 128 + @property def name(self) -> str: """ Generate a unique name for the kernel instance based on its parameters. """ - return ("_").join( - [ - "a8w8_blockscale_cktile", - ("x").join( - map( - lambda x: str(x), - [self.M_Tile, self.N_Tile, self.K_Tile], - ) - ), - ("x").join( - map( - lambda x: str(x), - [self.M_Warp, self.N_Warp, self.K_Warp], - ) - ), - ("x").join( - map( - lambda x: str(x), - [self.M_Warp_Tile, self.N_Warp_Tile, self.K_Warp_Tile], - ) - ), - self.Scheduler.lower(), - ("x").join( - map( - lambda x: str(int(x)), - [ - self.TiledMMAPermuteN, - self.TransposeC, - self.UsePersistentKernel, - ], - ) - ), - str(self.BlockPerCu), - ] - ) + parts = [ + "a8w8_blockscale_cktile", + ("x").join( + map( + lambda x: str(x), + [self.M_Tile, self.N_Tile, self.K_Tile], + ) + ), + ("x").join( + map( + lambda x: str(x), + [self.M_Warp, self.N_Warp, self.K_Warp], + ) + ), + ("x").join( + map( + lambda x: str(x), + [self.M_Warp_Tile, self.N_Warp_Tile, self.K_Warp_Tile], + ) + ), + self.Scheduler.lower(), + ("x").join( + map( + lambda x: str(int(x)), + [ + self.TiledMMAPermuteN, + self.TransposeC, + self.UsePersistentKernel, + ], + ) + ), + str(self.BlockPerCu), + ] + if self.AQRowMajor: + parts.append("aqrm") + return "_".join(parts) BLOCK_PER_CU_MAX = 4 @@ -131,6 +140,8 @@ def 
expand_blockpercu(base_dict, max_bpc=BLOCK_PER_CU_MAX, field_name="BlockPerC 9: TileKernelInstance( 128, 128, 128, 1, 4, 1, 16, 16, 128, "Intrawave", False, True, False, 1 ), 10: TileKernelInstance( 128, 128, 128, 2, 2, 1, 16, 16, 128, "Intrawave", False, True, False, 2 ), 11: TileKernelInstance( 192, 256, 128, 4, 2, 1, 16, 16, 128, "Intrawave", False, True, False, 1 ), + # Same tile config as kernel 11 (8 warps: 4x2x1) but with AQRowMajor=True: skip host-side x_scale transpose + 12: TileKernelInstance( 192, 256, 128, 4, 2, 1, 16, 16, 128, "Intrawave", False, True, False, 1, AQRowMajor=True), } default_kernels_cktile_dict = { diff --git a/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py b/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py index f54046a677..9be6fa337a 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py +++ b/csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py @@ -121,7 +121,8 @@ def gen_cktile_instance(self, k: TileKernelInstance): {str(k.TransposeC).lower()}, {str(k.UsePersistentKernel).lower()}, ck_tile::GemmPipelineScheduler::{k.Scheduler}, - {k.BlockPerCu}>; + {k.BlockPerCu}, + {str(k.AQRowMajor).lower()}>; // Run kernel instance. 
return gemm_a8w8_blockscale_cktile_impl(XQ, WQ, x_scale, w_scale, Y, preshuffleB); diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh index a1003cd47d..0c254d6422 100644 --- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh +++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh @@ -60,7 +60,8 @@ template + int BlockPerCu = 1, + bool AQRowMajor = false> struct CreateTileGemmConfig { static constexpr ck_tile::index_t M_Tile_v = M_Tile; @@ -77,6 +78,7 @@ struct CreateTileGemmConfig static constexpr bool UsePersistentKernel_v = UsePersistentKernel; static constexpr ck_tile::GemmPipelineScheduler Scheduler_v = Scheduler; static constexpr int BlockPerCu_v = BlockPerCu; + static constexpr bool AQRowMajor_v = AQRowMajor; }; template + int BlockPerCu = 1, + bool AQRowMajor = false> using TileGemmConfig = CreateTileGemmConfig; + BlockPerCu, + AQRowMajor>; template , @@ -145,7 +152,7 @@ void TileGemmComputeImpl(ck_tile::QuantGemmHostArgs& args) BLayout, CLayout, QuantMode, - std::conditional_t, + std::conditional_t, BQLayout, transpose_c, UseDoubleSmemBuffer>; @@ -308,7 +315,15 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ const int N = WQ.size(0); const int K = XQ.size(1); - const bool eight_waves = + // True for 8-warp kernels built without AQRowMajor: AQ is column-major, + // requiring a host-side transpose of x_scale (done only when !PreshuffleB). 
+ constexpr bool aq_col_major = + BQuantGroupSize::kN == 128 && + (GemmInstance::M_Warp_v * GemmInstance::N_Warp_v * GemmInstance::K_Warp_v == 8) && + GemmInstance::K_Warp_Tile_v == 128 && + !GemmInstance::AQRowMajor_v; + + constexpr bool eight_waves = BQuantGroupSize::kN == 128 && (GemmInstance::M_Warp_v * GemmInstance::N_Warp_v * GemmInstance::K_Warp_v == 8) && GemmInstance::K_Warp_Tile_v == 128; @@ -321,18 +336,34 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ // through the async kernel launch. torch::Tensor x_scale_t; - if(eight_waves && !PreshuffleB) + if constexpr(aq_col_major) { - x_scale_t = x_scale.transpose(0, 1).contiguous().view(x_scale.sizes()); - args.aq_ptr = x_scale_t.data_ptr(); + // 8-warp ColumnMajor AQ: transpose x_scale to col-major + if(!PreshuffleB) + { + x_scale_t = x_scale.transpose(0, 1).contiguous().view(x_scale.sizes()); + args.aq_ptr = x_scale_t.data_ptr(); + } + else + { + args.aq_ptr = x_scale.data_ptr(); + } } - else if(!eight_waves && PreshuffleB) + else if constexpr(!eight_waves) { - x_scale_t = x_scale.view({x_scale.size(1), x_scale.size(0)}).transpose(0, 1).contiguous(); - args.aq_ptr = x_scale_t.data_ptr(); + if(PreshuffleB) + { + x_scale_t = x_scale.view({x_scale.size(1), x_scale.size(0)}).transpose(0, 1).contiguous(); + args.aq_ptr = x_scale_t.data_ptr(); + } + else + { + args.aq_ptr = x_scale.data_ptr(); + } } else { + // 8-warp RowMajor AQ: use x_scale directly, no transpose needed args.aq_ptr = x_scale.data_ptr(); } @@ -357,7 +388,7 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ const int stride_A = XQ.stride(0); const int stride_B = WQ.stride(0); const int stride_C = Y.stride(0); - const int stride_AQ = eight_waves ? M : static_cast(x_scale.stride(0)); + const int stride_AQ = aq_col_major ? 
M : static_cast(x_scale.stride(0)); const int stride_BQ = w_scale.stride(0); args.QK_A = AQK;
diff --git a/op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py b/op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py
new file mode 100644
index 0000000000..125050c09c
--- /dev/null
+++ b/op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py
@@ -0,0 +1,307 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+"""
+Test that CKTile blockscale 8-warp kernels produce correct results with
+both AQLayout options (ColumnMajor = default, RowMajor).
+
+Verifies:
+  1. Both AQ layout variants produce output matching a PyTorch reference.
+  2. TileKernelInstance name encodes AQRowMajor correctly.
+  3. Padded weight stride handling remains correct for these kernels.
+"""
+
+import argparse
+import os
+import sys
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from aiter.jit.utils.chip_info import get_gfx
+
+if get_gfx() != "gfx950":
+    print(
+        f"Skipping test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py: "
+        f"AQRowMajor only supported on gfx950, detected {get_gfx()}"
+    )
+    sys.exit(0)
+
+import torch
+import torch.nn.functional as F
+from aiter import dtypes
+from aiter.test_common import checkAllclose
+
+BLOCK_SHAPE = (128, 128)
+
+
+def torch_reference(x, weight, x_scale, w_scale, dtype=dtypes.bf16):
+    """FP8 blockscale GEMM reference using PyTorch."""
+    block_n, block_k = BLOCK_SHAPE
+    m, k = x.shape
+    n = weight.shape[0]
+    scale_n = (n + block_n - 1) // block_n
+    scale_k = (k + block_k - 1) // block_k
+
+    x_f = x.to(x_scale.dtype).view(m, k // block_k, block_k) * x_scale.unsqueeze(-1)
+    x_f = x_f.view(m, k)
+
+    from einops import rearrange
+
+    ws = rearrange(
+        w_scale.view(-1, 1)
+        .repeat(1, block_n * block_k)
+        .view(scale_n, scale_k, block_n, block_k),
+        "bn bk n k -> (bn n) (bk k)",
+    )[:n, :k]
+    w_f = weight.to(ws.dtype) * ws
+
+    return F.linear(x_f.to(dtypes.fp32), w_f.to(dtypes.fp32)).to(dtype)
+
+
+def run_cktile_tune(x, weight, x_scale, w_scale, kernel_id, dtype=dtypes.bf16):
+    """Invoke a specific CKTile kernel by ID via the tune entry point."""
+    from aiter.ops.gemm_op_a8w8 import gemm_a8w8_blockscale_cktile_tune
+
+    m, k = x.shape
+    n = weight.shape[0]
+    Y = torch.empty(m, n, dtype=dtype, device=x.device)
+    return gemm_a8w8_blockscale_cktile_tune(
+        x, weight, x_scale, w_scale, Y, kernelId=kernel_id
+    )
+
+
+def test_instance_names():
+    """Verify kernel name encoding includes 'aqrm' suffix for AQRowMajor."""
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    sys.path.insert(0, os.path.join(this_dir, "..", "csrc", "ck_gemm_a8w8_blockscale"))
+    from gemm_a8w8_blockscale_cktile_instance import (
+        TileKernelInstance,
+        candidate_kernels_cktile_dict,
+    )
+
+    # Non-AQRowMajor name should NOT end with _aqrm (4x2x1 = 8 warps)
+    inst_default = TileKernelInstance(
+        192,
+        256,
+        128,
+        4,
+        2,
+        1,
+        16,
+        16,
+        128,
+        "Intrawave",
+        False,
+        True,
+        False,
+        1,
+    )
+    assert not inst_default.name.endswith(
+        "_aqrm"
+    ), f"Default instance name should not have _aqrm suffix: {inst_default.name}"
+
+    # AQRowMajor=True name SHOULD end with _aqrm
+    inst_rm = TileKernelInstance(
+        192,
+        256,
+        128,
+        4,
+        2,
+        1,
+        16,
+        16,
+        128,
+        "Intrawave",
+        False,
+        True,
+        False,
+        1,
+        AQRowMajor=True,
+    )
+    assert inst_rm.name.endswith(
+        "_aqrm"
+    ), f"AQRowMajor instance name should have _aqrm suffix: {inst_rm.name}"
+
+    # Names must be distinct
+    assert inst_default.name != inst_rm.name, "Names must differ"
+
+    # Verify is_eight_warp property (4x2x1 = 8 warps, K_Warp_Tile=128)
+    assert inst_rm.is_eight_warp, "4x2x1 with K_Warp_Tile=128 should be 8-warp"
+
+    non_8w = TileKernelInstance(
+        16,
+        128,
+        256,
+        1,
+        4,
+        1,
+        16,
+        16,
+        64,
+        "Intrawave",
+        False,
+        True,
+        False,
+        1,
+    )
+    assert not non_8w.is_eight_warp, "1x4x1 with K_Warp_Tile=64 is not 8-warp"
+
+    # Check that RowMajor variants exist in the candidate dict
+    aqrm_kernels = {
+        kid: k
+        for kid, k in candidate_kernels_cktile_dict.items()
+        if getattr(k, "AQRowMajor", False)
+    }
+    print(f" Found {len(aqrm_kernels)} AQRowMajor kernel variants in candidate dict")
+    assert len(aqrm_kernels) > 0, "Expected at least one AQRowMajor kernel variant"
+
+    for kid, k in aqrm_kernels.items():
+        assert k.is_eight_warp, (
+            f"AQRowMajor kernel {kid} ({k.name}) should be 8-warp "
+            f"(warps={k.M_Warp}x{k.N_Warp}x{k.K_Warp}={k.M_Warp*k.N_Warp*k.K_Warp})"
+        )
+        assert (
+            "_aqrm" in k.name
+        ), f"AQRowMajor kernel {kid} should have _aqrm in name: {k.name}"
+
+    print(" PASSED: instance name encoding")
+
+
+def test_accuracy(m, n, k, dtype=dtypes.bf16, err_threshold=0.05):
+    """Check both AQ layout variants against the reference within err_threshold."""
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    sys.path.insert(0, os.path.join(this_dir, "..", "csrc", "ck_gemm_a8w8_blockscale"))
+    from gemm_a8w8_blockscale_cktile_instance import candidate_kernels_cktile_dict
+
+    block_n, block_k = BLOCK_SHAPE
+    scale_m = m
+    scale_n = (n + block_n - 1) // block_n
+    scale_k = (k + block_k - 1) // block_k
+
+    x = (torch.rand((m, k), dtype=dtypes.fp32, device="cuda") / 10).to(dtypes.fp8)
+    weight = (torch.rand((n, k), dtype=dtypes.fp32, device="cuda") / 10).to(dtypes.fp8)
+    x_scale = torch.rand([scale_m, scale_k], dtype=dtypes.fp32, device="cuda")
+    w_scale = torch.rand([scale_n, scale_k], dtype=dtypes.fp32, device="cuda")
+
+    ref = torch_reference(x, weight, x_scale, w_scale, dtype)
+
+    # Find 8-warp kernels: both ColumnMajor and RowMajor variants
+    eight_warp_colmajor = {}
+    eight_warp_rowmajor = {}
+    for kid, inst in candidate_kernels_cktile_dict.items():
+        if not inst.is_eight_warp:
+            continue
+        if getattr(inst, "AQRowMajor", False):
+            eight_warp_rowmajor[kid] = inst
+        else:
+            eight_warp_colmajor[kid] = inst
+
+    if not eight_warp_colmajor:
+        print(f" SKIP: no 8-warp ColumnMajor kernels for shape ({m},{n},{k})")
+        return
+    if not eight_warp_rowmajor:
+        print(f" SKIP: no 8-warp RowMajor kernels for shape ({m},{n},{k})")
+        return
+
+    # Test one ColumnMajor and one RowMajor kernel
+    cm_kid, cm_inst = next(iter(eight_warp_colmajor.items()))
+    rm_kid, rm_inst = next(iter(eight_warp_rowmajor.items()))
+
+    print(f" Testing ColumnMajor kernel {cm_kid} ({cm_inst.name})")
+    out_cm = run_cktile_tune(x, weight, x_scale, w_scale, cm_kid, dtype)
+    err_cm = checkAllclose(ref, out_cm, msg=f"ColMajor(id={cm_kid})")
+    # Enforce the tolerance; otherwise PASSED is printed whatever the error is.
+    # NOTE(review): assumes checkAllclose returns an error ratio comparable
+    # to err_threshold -- confirm against aiter.test_common.
+    assert err_cm <= err_threshold, f"ColMajor(id={cm_kid}) err={err_cm:.4f}"
+
+    print(f" Testing RowMajor kernel {rm_kid} ({rm_inst.name})")
+    out_rm = run_cktile_tune(x, weight, x_scale, w_scale, rm_kid, dtype)
+    err_rm = checkAllclose(ref, out_rm, msg=f"RowMajor(id={rm_kid})")
+    assert err_rm <= err_threshold, f"RowMajor(id={rm_kid}) err={err_rm:.4f}"
+
+    # Also check that both outputs are close to each other
+    err_x = checkAllclose(out_cm, out_rm, msg="ColMajor vs RowMajor")
+    assert err_x <= err_threshold, f"ColMajor vs RowMajor err={err_x:.4f}"
+
+    print(f" PASSED: shape ({m},{n},{k}) " f"cm_err={err_cm:.4f} rm_err={err_rm:.4f}")
+
+
+def test_padded_weight_stride(m, n, k, dtype=dtypes.bf16, err_threshold=0.05):
+    """Test that RowMajor variant works with padded (non-contiguous) weight tensors,
+    similar to vLLM's _maybe_pad_fp8_weight."""
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    sys.path.insert(0, os.path.join(this_dir, "..", "csrc", "ck_gemm_a8w8_blockscale"))
+    from gemm_a8w8_blockscale_cktile_instance import candidate_kernels_cktile_dict
+
+    block_n, block_k = BLOCK_SHAPE
+    scale_m = m
+    scale_n = (n + block_n - 1) // block_n
+    scale_k = (k + block_k - 1) // block_k
+
+    x = (torch.rand((m, k), dtype=dtypes.fp32, device="cuda") / 10).to(dtypes.fp8)
+    weight_orig = (torch.rand((n, k), dtype=dtypes.fp32, device="cuda") / 10).to(
+        dtypes.fp8
+    )
+    x_scale = torch.rand([scale_m, scale_k], dtype=dtypes.fp32, device="cuda")
+    w_scale = torch.rand([scale_n, scale_k], dtype=dtypes.fp32, device="cuda")
+
+    # Simulate _maybe_pad_fp8_weight: pad last dim, then narrow back
+    num_pad = 256
+    weight_padded = F.pad(weight_orig, (0, num_pad), "constant", 0)[..., :-num_pad]
+    assert weight_padded.shape == weight_orig.shape
+    assert (
+        weight_padded.stride(0) == k + num_pad
+    ), f"Expected stride {k + num_pad}, got {weight_padded.stride(0)}"
+    assert weight_padded.stride(-1) == 1
+
+    ref = torch_reference(x, weight_orig, x_scale, w_scale, dtype)
+
+    # Find any RowMajor 8-warp kernel
+    rm_kid = None
+    for kid, inst in candidate_kernels_cktile_dict.items():
+        if inst.is_eight_warp and getattr(inst, "AQRowMajor", False):
+            rm_kid = kid
+            break
+
+    if rm_kid is None:
+        print(" SKIP: no RowMajor 8-warp kernel available")
+        return
+
+    out = run_cktile_tune(x, weight_padded, x_scale, w_scale, rm_kid, dtype)
+    err = checkAllclose(ref, out, msg=f"PaddedWeight+RowMajor(id={rm_kid})")
+    assert err <= err_threshold, f"PaddedWeight+RowMajor(id={rm_kid}) err={err:.4f}"
+    print(f" PASSED: padded weight shape ({m},{n},{k}) err={err:.4f}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Test CKTile blockscale GEMM AQRowMajor optimization"
+    )
+    parser.add_argument(
+        "--skip-accuracy",
+        action="store_true",
+        help="Skip accuracy tests (useful for fast name-only checks)",
+    )
+    args = parser.parse_args()
+
+    print("=== Test 1: Instance name encoding ===")
+    test_instance_names()
+
+    if not args.skip_accuracy:
+        shapes = [
+            (128, 2048, 7168),
+            (256, 7168, 2048),
+            (1, 7168, 2048),
+            (512, 4096, 7168),
+        ]
+        print("\n=== Test 2: Accuracy (ColumnMajor vs RowMajor vs Reference) ===")
+        for m, n, k in shapes:
+            print(f"\nShape: M={m}, N={n}, K={k}")
+            test_accuracy(m, n, k)
+
+        print("\n=== Test 3: Padded weight stride handling ===")
+        for m, n, k in shapes[:2]:
+            print(f"\nShape: M={m}, N={n}, K={k}")
+            test_padded_weight_stride(m, n, k)
+
+    print("\n=== All tests passed ===")