From f4f95ec779fc415c79581b06f44a4721dfe1a6d6 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Fri, 24 Oct 2025 19:45:57 +0000 Subject: [PATCH 01/11] Kernel naming: add reusable constexpr repr helper --- .../ops/triton/_triton_kernels/gemm_a16w16.py | 19 +++++++- aiter/ops/triton/utils/_triton/kernel_repr.py | 44 +++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 aiter/ops/triton/utils/_triton/kernel_repr.py diff --git a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py index 33281106eb..b01023d5b7 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py @@ -6,6 +6,23 @@ from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_gemm_a16w16_repr = make_kernel_repr( + "_gemm_a16_w16_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "cache_modifier", + "activation", + "use_activation", + ], +) @triton.heuristics( @@ -16,7 +33,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_gemm_a16w16_repr) def _gemm_a16_w16_kernel( a_ptr, b_ptr, diff --git a/aiter/ops/triton/utils/_triton/kernel_repr.py b/aiter/ops/triton/utils/_triton/kernel_repr.py new file mode 100644 index 0000000000..71f66eec93 --- /dev/null +++ b/aiter/ops/triton/utils/_triton/kernel_repr.py @@ -0,0 +1,44 @@ +def _sanitize_constexpr_value(value): + if value is None: + return "NONE" + if isinstance(value, bool): + return str(int(value)) + if isinstance(value, int): + return str(value) + if isinstance(value, float): + if value.is_integer(): + return str(int(value)) + return str(value) + + # for lists, tuples, sets - recursively join each + if isinstance(value, (list, tuple, set)): + items = 
sorted(value, key=str) if isinstance(value, set) else value + sanitized_items = [_sanitize_constexpr_value(item) for item in items] + joined = "_".join(sanitized_items) + return joined if joined else "NONE" + + if isinstance(value, str): + cleaned_value = "".join(ch if ch.isalnum() else "_" for ch in value).strip("_") + return cleaned_value.upper() if cleaned_value else "NONE" + + cleaned_value = "".join(ch if ch.isalnum() else "_" for ch in str(value)).strip("_") + return cleaned_value.upper() if cleaned_value else "NONE" + + +def make_kernel_repr(base_name, config_keys): + def _repr(specialization): + constants = specialization.constants + name_parts = [] + + for key in config_keys: + value = constants.get(key, None) + symbol = _sanitize_constexpr_value(value) + name_parts.append(f"{key}_{symbol}") + + if not name_parts: + return base_name + + suffix = "_".join(name_parts) + return f"{base_name}_{suffix}" + + return _repr From c61b5ee33ce5db6ad46c90e5a6413cf243e99ed6 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Mon, 3 Nov 2025 15:52:18 +0000 Subject: [PATCH 02/11] add missing params to the repr --- aiter/ops/triton/_triton_kernels/gemm_a16w16.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py index b01023d5b7..3561a9985f 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py @@ -18,9 +18,13 @@ "GROUP_SIZE_M", "NUM_KSPLIT", "SPLITK_BLOCK_SIZE", + "EVEN_K", + "GRID_MN", "cache_modifier", "activation", "use_activation", + "ADD_BIAS", + "SKIP_REDUCE", ], ) From 42208d4d572673d0a77986f5e0d58f8eb249c1c2 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Tue, 28 Oct 2025 17:59:29 +0000 Subject: [PATCH 03/11] gemm kernels nomenclature changes --- .../ops/triton/_triton_kernels/gemm_a16w16.py | 16 +++++- .../_triton_kernels/gemm_a16w16_atomic.py | 19 ++++++- 
.../_triton_kernels/gemm_a16w16_gated.py | 19 ++++++- aiter/ops/triton/_triton_kernels/gemm_a8w8.py | 18 ++++++- .../_triton_kernels/gemm_a8w8_blockscale.py | 34 ++++++++++++- .../gemm_a8w8_per_token_scale.py | 32 +++++++++++- .../ops/triton/_triton_kernels/gemm_a8wfp4.py | 20 +++++++- .../triton/_triton_kernels/gemm_afp4wfp4.py | 50 +++++++++++++++++-- .../gemm_afp4wfp4_pre_quant_atomic.py | 19 ++++++- 9 files changed, 214 insertions(+), 13 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py index 3561a9985f..f5377e7eb3 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py @@ -29,6 +29,20 @@ ) +_gemm_a16w16_reduce_repr = make_kernel_repr( + "_gemm_a16w16_reduce_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "ACTUAL_KSPLIT", + "MAX_KSPLIT", + "activation", + "use_activation", + "ADD_BIAS", + ], +) + + @triton.heuristics( { "EVEN_K": lambda args: (args["K"] % (args["SPLITK_BLOCK_SIZE"]) == 0) @@ -169,7 +183,7 @@ def _gemm_a16_w16_kernel( tl.store(c_ptrs, c, mask=c_mask) -@triton.jit +@triton.jit(repr=_gemm_a16w16_reduce_repr) def _gemm_a16w16_reduce_kernel( bias_ptr, c_in_ptr, diff --git a/aiter/ops/triton/_triton_kernels/gemm_a16w16_atomic.py b/aiter/ops/triton/_triton_kernels/gemm_a16w16_atomic.py index 00963aba1c..d08a1b2760 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a16w16_atomic.py @@ -10,6 +10,23 @@ from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_gemm_a16w16_atomic_repr = make_kernel_repr( + "_gemm_a16_w16_atomic_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "cache_modifier", + "EVEN_K", + "GRID_MN", + ], +) 
@triton.heuristics( @@ -21,7 +38,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_gemm_a16w16_atomic_repr) def _gemm_a16_w16_atomic_kernel( a_ptr, b_ptr, diff --git a/aiter/ops/triton/_triton_kernels/gemm_a16w16_gated.py b/aiter/ops/triton/_triton_kernels/gemm_a16w16_gated.py index fc597aea75..4466053c67 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a16w16_gated.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a16w16_gated.py @@ -10,6 +10,23 @@ from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH from .activation import _get_activation_from_str +from ..utils._triton.kernel_repr import make_kernel_repr + + +_gemm_a16w16_gated_repr = make_kernel_repr( + "_gemm_a16_w16_gated_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "GRID_MN", + "cache_modifier", + "activation", + "use_activation", + ], +) @triton.heuristics( @@ -19,7 +36,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_gemm_a16w16_gated_repr) def _gemm_a16_w16_gated_kernel( a_ptr, b_ptr, diff --git a/aiter/ops/triton/_triton_kernels/gemm_a8w8.py b/aiter/ops/triton/_triton_kernels/gemm_a8w8.py index 57e5bf1dbc..1755c41d40 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a8w8.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a8w8.py @@ -8,6 +8,22 @@ import triton.language as tl from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_gemm_a8w8_repr = make_kernel_repr( + "_gemm_a8w8_kernel", + [ + "HAS_BIAS", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "GRID_MN", + "NUM_XCDS", + ], +) @triton.heuristics( @@ -17,7 +33,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_gemm_a8w8_repr) def _gemm_a8w8_kernel( # Pointers to matrices a_ptr, diff --git 
a/aiter/ops/triton/_triton_kernels/gemm_a8w8_blockscale.py b/aiter/ops/triton/_triton_kernels/gemm_a8w8_blockscale.py index 9343d40787..faf545579b 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a8w8_blockscale.py @@ -11,6 +11,36 @@ from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_gemm_a8w8_blockscale_repr = make_kernel_repr( + "_gemm_a8w8_blockscale_kernel", + [ + "GROUP_K", + "GROUP_N", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "GRID_MN", + "cache_modifier", + ], +) + + +_gemm_a8w8_blockscale_reduce_repr = make_kernel_repr( + "_gemm_a8w8_blockscale_reduce_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "ACTUAL_KSPLIT", + "MAX_KSPLIT", + ], +) @triton.heuristics( @@ -20,7 +50,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_gemm_a8w8_blockscale_repr) def _gemm_a8w8_blockscale_kernel( # Pointers to matrices a_ptr, @@ -195,7 +225,7 @@ def _gemm_a8w8_blockscale_kernel( tl.store(c_ptrs, c, mask=c_mask) -@triton.jit +@triton.jit(repr=_gemm_a8w8_blockscale_reduce_repr) def _gemm_a8w8_blockscale_reduce_kernel( c_in_ptr, c_out_ptr, diff --git a/aiter/ops/triton/_triton_kernels/gemm_a8w8_per_token_scale.py b/aiter/ops/triton/_triton_kernels/gemm_a8w8_per_token_scale.py index ed5ef5a601..32c4ccee91 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a8w8_per_token_scale.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a8w8_per_token_scale.py @@ -9,6 +9,34 @@ from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_gemm_a8w8_per_token_scale_repr = 
make_kernel_repr( + "_gemm_a8w8_per_token_scale_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "GRID_MN", + "cache_modifier", + ], +) + + +_gemm_a8w8_per_token_scale_reduce_repr = make_kernel_repr( + "_gemm_a8w8_per_token_scale_reduce_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "ACTUAL_KSPLIT", + "MAX_KSPLIT", + ], +) @triton.heuristics( @@ -18,7 +46,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_gemm_a8w8_per_token_scale_repr) def _gemm_a8w8_per_token_scale_kernel( # Pointers to matrices a_ptr, @@ -167,7 +195,7 @@ def _gemm_a8w8_per_token_scale_kernel( tl.store(c_ptrs, c, mask=c_mask) -@triton.jit +@triton.jit(repr=_gemm_a8w8_per_token_scale_reduce_repr) def _gemm_a8w8_per_token_scale_reduce_kernel( c_in_ptr, c_out_ptr, diff --git a/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py index 2d5c7b9202..2ad97d5b23 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py @@ -8,6 +8,24 @@ from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_gemm_a8wfp4_repr = make_kernel_repr( + "_gemm_a8wfp4_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "GRID_MN", + "RAW_MASKED_LOADS", + "cache_modifier", + ], +) @triton.heuristics( @@ -19,7 +37,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_gemm_a8wfp4_repr) def _gemm_a8wfp4_kernel( a_ptr, b_ptr, diff --git a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py index 918474fdcd..ac9a2c1a69 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py 
+++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py @@ -9,6 +9,50 @@ from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_gemm_afp4wfp4_repr = make_kernel_repr( + "_gemm_afp4_wfp4_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "GRID_MN", + "cache_modifier", + ], +) + + +_gemm_afp4wfp4_preshuffled_repr = make_kernel_repr( + "_gemm_afp4_wfp4_kernel_preshuffled_scales", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "GRID_MN", + "cache_modifier", + ], +) + + +_gemm_afp4wfp4_reduce_repr = make_kernel_repr( + "_gemm_afp4_wfp4_reduce_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "ACTUAL_KSPLIT", + "MAX_KSPLIT", + ], +) @triton.heuristics( @@ -18,7 +62,7 @@ and (args["K"] % (args["SPLITK_BLOCK_SIZE"] // 2) == 0), } ) -@triton.jit +@triton.jit(repr=_gemm_afp4wfp4_repr) def _gemm_afp4_wfp4_kernel( a_ptr, b_ptr, @@ -173,7 +217,7 @@ def _gemm_afp4_wfp4_kernel( and (args["K"] % (args["SPLITK_BLOCK_SIZE"] // 2) == 0), } ) -@triton.jit +@triton.jit(repr=_gemm_afp4wfp4_preshuffled_repr) def _gemm_afp4_wfp4_kernel_preshuffled_scales( a_ptr, b_ptr, @@ -585,7 +629,7 @@ def _gemm_afp4_wfp4_kernel_preshuffled_weight_scales( tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt") -@triton.jit +@triton.jit(repr=_gemm_afp4wfp4_reduce_repr) def _gemm_afp4_wfp4_reduce_kernel( c_in_ptr, c_out_ptr, diff --git a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4_pre_quant_atomic.py index bc4edd2a4b..1456f2a1a3 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4_pre_quant_atomic.py @@ -12,6 +12,23 @@ from 
..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH from .quant import _mxfp4_quant_op +from ..utils._triton.kernel_repr import make_kernel_repr + + +_gemm_afp4wfp4_pre_quant_repr = make_kernel_repr( + "_gemm_afp4_wfp4_pre_quant_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "GRID_MN", + "cache_modifier", + ], +) @triton.heuristics( @@ -23,7 +40,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_gemm_afp4wfp4_pre_quant_repr) def _gemm_afp4_wfp4_pre_quant_kernel( a_ptr, b_ptr, From 1092f775d4368227e7a61ac45cfe54f8786b7669 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Tue, 28 Oct 2025 19:10:43 +0000 Subject: [PATCH 04/11] add gemm_afp4_wfp4_reduce_repr too --- aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py index 2ad97d5b23..2afcf50904 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py @@ -27,6 +27,16 @@ ], ) +_gemm_afp4_wfp4_reduce_repr = make_kernel_repr( + "_gemm_afp4_wfp4_reduce_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "ACTUAL_KSPLIT", + "MAX_KSPLIT", + ], +) + @triton.heuristics( { @@ -201,7 +211,7 @@ def _gemm_a8wfp4_kernel( tl.store(c_ptrs, c, mask=c_mask) -@triton.jit +@triton.jit(repr=_gemm_afp4_wfp4_reduce_repr) def _gemm_afp4_wfp4_reduce_kernel( c_in_ptr, c_out_ptr, @@ -315,8 +325,6 @@ def _get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): else: break - SPLITK_BLOCK_SIZE = ( - triton.cdiv((2 * triton.cdiv(K, NUM_KSPLIT)), BLOCK_SIZE_K) * BLOCK_SIZE_K - ) + SPLITK_BLOCK_SIZE = triton.cdiv((2 * triton.cdiv(K, NUM_KSPLIT)), BLOCK_SIZE_K) * BLOCK_SIZE_K return SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT From 
0809355a352051a50be469fb7d64650583544075 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Tue, 28 Oct 2025 15:14:38 -0400 Subject: [PATCH 05/11] Update gemm_a8wfp4.py --- aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py index 2afcf50904..8196c5cf49 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py @@ -325,6 +325,8 @@ def _get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): else: break - SPLITK_BLOCK_SIZE = triton.cdiv((2 * triton.cdiv(K, NUM_KSPLIT)), BLOCK_SIZE_K) * BLOCK_SIZE_K + SPLITK_BLOCK_SIZE = ( + triton.cdiv((2 * triton.cdiv(K, NUM_KSPLIT)), BLOCK_SIZE_K) * BLOCK_SIZE_K + ) return SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT From 8a3362ef148ca9f9242fa9b3bf4e3121210f4710 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Mon, 3 Nov 2025 19:49:02 +0000 Subject: [PATCH 06/11] remove unused parameters --- aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py index ac9a2c1a69..a447ed6c9a 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py @@ -22,7 +22,6 @@ "NUM_KSPLIT", "SPLITK_BLOCK_SIZE", "EVEN_K", - "GRID_MN", "cache_modifier", ], ) @@ -38,7 +37,6 @@ "NUM_KSPLIT", "SPLITK_BLOCK_SIZE", "EVEN_K", - "GRID_MN", "cache_modifier", ], ) From 2daf0143f6bf3f5a2c20cdc046eb2b4304f500b3 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Mon, 3 Nov 2025 20:53:54 +0000 Subject: [PATCH 07/11] Add repr to _gemm_afp4_wfp4_kernel_preshuffled_weight_scales --- .../ops/triton/_triton_kernels/gemm_afp4wfp4.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git 
a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py index a447ed6c9a..a04764866b 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py @@ -53,6 +53,21 @@ ) +_gemm_afp4wfp4_preshuffled_weight_scales_repr = make_kernel_repr( + "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "cache_modifier", + ], +) + + @triton.heuristics( { "EVEN_K": lambda args: (args["K"] % (args["BLOCK_SIZE_K"] // 2) == 0) @@ -419,7 +434,7 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales( and (args["K"] % (args["SPLITK_BLOCK_SIZE"] // 2) == 0), } ) -@triton.jit +@triton.jit(repr=_gemm_afp4wfp4_preshuffled_weight_scales_repr) def _gemm_afp4_wfp4_kernel_preshuffled_weight_scales( a_ptr, b_ptr, From 58953c79392590a951605dad313bdfcb4502221d Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Tue, 4 Nov 2025 18:25:56 -0500 Subject: [PATCH 08/11] [TRITON] Kernel naming: add reusable constexpr repr helper (#1260) * Kernel naming: add reusable constexpr repr helper for gemm a16w16 * add missing params to the repr --- 3rdparty/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 7c6430eca0..32773fe5cb 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 7c6430eca04e62454217630ae2a0bbd70ff50a00 +Subproject commit 32773fe5cb176efd2fcbb361f183164fc6525d8a From b71d49e6d5e2eef7c3bedf1860caffc8c7d95463 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Wed, 5 Nov 2025 16:36:37 +0000 Subject: [PATCH 09/11] Add missing API documentation --- .../ops/triton/_triton_kernels/gemm_a8wfp4.py | 13 ++- .../triton/_triton_kernels/gemm_afp4wfp4.py | 30 +++---- .../gemm_afp4wfp4_pre_quant_atomic.py | 3 +- 
aiter/ops/triton/gemm_a16w16.py | 27 +++--- aiter/ops/triton/gemm_a16w16_atomic.py | 20 +++-- aiter/ops/triton/gemm_a16w16_gated.py | 22 ++--- aiter/ops/triton/gemm_a8w8.py | 27 +++--- aiter/ops/triton/gemm_a8w8_blockscale.py | 26 +++--- aiter/ops/triton/gemm_a8w8_per_token_scale.py | 20 +++-- aiter/ops/triton/gemm_a8wfp4.py | 40 +++++---- aiter/ops/triton/gemm_afp4wfp4.py | 82 +++++++++++-------- .../triton/gemm_afp4wfp4_pre_quant_atomic.py | 24 +++--- 12 files changed, 188 insertions(+), 146 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py index 8196c5cf49..9db620cda0 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py @@ -32,8 +32,14 @@ [ "BLOCK_SIZE_M", "BLOCK_SIZE_N", - "ACTUAL_KSPLIT", - "MAX_KSPLIT", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "GRID_MN", + "RAW_MASKED_LOADS", + "cache_modifier", ], ) @@ -80,7 +86,8 @@ def _gemm_a8wfp4_kernel( RAW_MASKED_LOADS: tl.constexpr, cache_modifier: tl.constexpr, ): - """Kernel for computing the matmul C = A x B. + """ + Kernel for computing the matmul C = A x B. A is in fp8 e4m3 format. B is in the microscale fp4 (mxfp4) format. A_scales and B_scales are in e8m0 format. 
diff --git a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py index a04764866b..d9f6c84ac5 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py @@ -42,17 +42,6 @@ ) -_gemm_afp4wfp4_reduce_repr = make_kernel_repr( - "_gemm_afp4_wfp4_reduce_kernel", - [ - "BLOCK_SIZE_M", - "BLOCK_SIZE_N", - "ACTUAL_KSPLIT", - "MAX_KSPLIT", - ], -) - - _gemm_afp4wfp4_preshuffled_weight_scales_repr = make_kernel_repr( "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales", [ @@ -68,6 +57,16 @@ ) +_gemm_afp4wfp4_reduce_repr = make_kernel_repr( + "_gemm_afp4_wfp4_reduce_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "ACTUAL_KSPLIT", + "MAX_KSPLIT", + ], +) + @triton.heuristics( { "EVEN_K": lambda args: (args["K"] % (args["BLOCK_SIZE_K"] // 2) == 0) @@ -106,7 +105,8 @@ def _gemm_afp4_wfp4_kernel( EVEN_K: tl.constexpr, cache_modifier: tl.constexpr, ): - """Kernel for computing the matmul C = A x B. + """ + Kernel for computing the matmul C = A x B. A and B inputs are in the microscale fp4 (mxfp4) format. A_scales and B_scales are in e8m0 format. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -261,7 +261,8 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales( EVEN_K: tl.constexpr, cache_modifier: tl.constexpr, ): - """Kernel for computing the matmul C = A x B. + """ + Kernel for computing the matmul C = A x B. A and B inputs are in the microscale fp4 (mxfp4) format. A_scales and B_scales are in e8m0 format. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -465,7 +466,8 @@ def _gemm_afp4_wfp4_kernel_preshuffled_weight_scales( EVEN_K: tl.constexpr, cache_modifier: tl.constexpr, ): - """Kernel for computing the matmul C = A x B. + """ + Kernel for computing the matmul C = A x B. A and B inputs are in the microscale fp4 (mxfp4) format. A_scales and B_scales are in e8m0 format. 
A has shape (M, K), B has shape (K, N) and C has shape (M, N) diff --git a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4_pre_quant_atomic.py index 1456f2a1a3..0d27d412c6 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4_pre_quant_atomic.py @@ -69,7 +69,8 @@ def _gemm_afp4_wfp4_pre_quant_kernel( GRID_MN: tl.constexpr, cache_modifier: tl.constexpr, ): - """Kernel for computing the matmul C = A x B. + """ + Kernel for computing the matmul C = A x B. A and B inputs are in the microscale fp4 (mxfp4) format. A_scales and B_scales are in e8m0 format. A has shape (M, K), B has shape (K, N) and C has shape (M, N) diff --git a/aiter/ops/triton/gemm_a16w16.py b/aiter/ops/triton/gemm_a16w16.py index b3d4f00bd3..549ddd36d8 100644 --- a/aiter/ops/triton/gemm_a16w16.py +++ b/aiter/ops/triton/gemm_a16w16.py @@ -27,19 +27,24 @@ def gemm_a16w16( skip_reduce: Optional[bool] = False, ): """ - Computes the 16 bit matmul Y = X x W - - Key parameters: - - X: Matrix X with shape (M, K). - - W: Matrix W with shape (N, K). - - dtype: Optional parameter to specifcy bf16 or fp16 datatype. Default is bf16 - - Y: Output Matrix Y with shape (M, N). - If this is none, then it's created by this API and returned as output. - - activation: Optional activation function to apply to the output. - One of ("gelu", "gelu_tanh", "silu", "silu_exp2", "relu"). Default is None. + Computes 16 bit matrix multiplication Y = X @ W^T + + Args: + x (torch.Tensor): Input matrix with shape (M, K). + w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. + bias (Optional[torch.Tensor]): Bias vector with shape (N,). + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). 
+ config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). + activation (Optional[str]): Activation function ("gelu", "gelu_tanh", "silu", + "silu_exp2", "relu"). + skip_reduce (Optional[bool]): Skip reduction of split-K partial results. + Enables kernel fusion with downstream operations (FP8/FP4 quantization, + RMSNorm). Returns shape (NUM_KSPLIT, M, N) instead of (M, N). Returns: - - Y: The output matrix with shape (M, N). + torch.Tensor: Output with shape (M, N) or (NUM_KSPLIT, M, N) if skip_reduce=True. """ _LOGGER.info(f"GEMM_A16W16: x={tuple(x.shape)} w={tuple(w.shape)}") diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index bbb5c5c63f..78026c80f0 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -23,16 +23,20 @@ def gemm_a16w16_atomic( config: Optional[dict] = None, ): """ - Computes the 16 bit matmul Y = X x W - NOTE: If dtype is set to bf16, aggregation in bf16 with atomic_add will lead to slight precision loss. - Key parameters: - - X: Matrix X with shape (M, K). - - W: Matrix W with shape (N, K). - - dtype: Optional parameter to specifcy bf16 or fp16 datatype. Default is bf16 - - Y: Output Matrix Y with shape (M, N). If this is none, then it's created by this API and returned as output + Computes 16 bit matrix multiplication Y = X @ W^T using atomic operations for split-K reduction. + + Args: + x (torch.Tensor): Input matrix with shape (M, K). + w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + Note: BF16 atomic aggregation may have slight precision loss. + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + Must be zero-initialized for split-K (NUM_KSPLIT > 1). 
+ config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, cache_modifier). Returns: - - Y: The output matrix with shape (M, N). + torch.Tensor: Output with shape (M, N). """ _LOGGER.info( diff --git a/aiter/ops/triton/gemm_a16w16_gated.py b/aiter/ops/triton/gemm_a16w16_gated.py index 33fcc13abf..8871daebbc 100644 --- a/aiter/ops/triton/gemm_a16w16_gated.py +++ b/aiter/ops/triton/gemm_a16w16_gated.py @@ -24,19 +24,21 @@ def gemm_a16w16_gated( activation: Optional[str] = None, ): """ - Computes the 16 bit matmul Y = X x W - Uses the first half of the output (along the N dim) as a gate for the second half (e.g for SwiGLU) + Computes 16 bit gated matrix multiplication Y = X @ W^T with gating mechanism (e.g., SwiGLU). + Uses first half of W output as gate for second half, producing (M, N//2) output. - Key parameters: - - X: Matrix X with shape (M, K). - - W: Matrix W with shape (N, K). - - dtype: Optional parameter to specifcy bf16 or fp16 datatype. Default is bf16 - - Y: Output Matrix Y with shape (M, N//2). - If this is none, then it's created by this API and returned as output. - - activation: Optional activation function to apply to the output. One of ("gelu", "gelu_tanh", "silu", "silu_exp2", "relu") + Args: + x (torch.Tensor): Input matrix with shape (M, K). + w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. N must be even. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N//2). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M). + activation (Optional[str]): Activation function applied to gate ("gelu", "gelu_tanh", + "silu", "silu_exp2", "relu"). Returns: - - Y: The output matrix with shape (M, N//2). + torch.Tensor: Gated output with shape (M, N//2). 
""" _LOGGER.info(f"GEMM_A16W16_GATED: x={tuple(x.shape)} w={tuple(w.shape)}") diff --git a/aiter/ops/triton/gemm_a8w8.py b/aiter/ops/triton/gemm_a8w8.py index 66c14e4470..3602ef2ff6 100644 --- a/aiter/ops/triton/gemm_a8w8.py +++ b/aiter/ops/triton/gemm_a8w8.py @@ -27,21 +27,22 @@ def gemm_a8w8( config: Optional[dict] = None, ): """ - Computes the 8 bit matmul Y = X x WT, applies a conversion scale and optionally adds a bias - to the result. - The conversion scale is received in the form of two 1D tensors that are multiplied to form a - 2D one before being applied. - - Key parameters: - - X: Matrix X with shape (M, K). - - W: Matrix W with shape (N, K). - - X_scale: First scale tensor with shape (M, 1). - - W_scale: Second scale tensor with shape (1, N). - - Bias: Bias tensor with shape (1, N). - - Y: Output Matrix Y with shape (M, K). If this is none, then it's created by this API and returned as output + Computes 8 bit matrix multiplication Y = (X @ W^T) * (x_scale * w_scale) with optional bias. + INT8 inputs are scaled back to higher precision using per-row (x) and per-column (w) scale factors. + + Args: + x (torch.Tensor): INT8 input matrix with shape (M, K). + w (torch.Tensor): INT8 weight matrix with shape (N, K), internally transposed. + x_scale (torch.Tensor): Scale factor for x with shape (M, 1) or (M,). + w_scale (torch.Tensor): Scale factor for w with shape (1, N) or (N,). + bias (Optional[torch.Tensor]): Bias vector with shape (N,). + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M). Returns: - - Y: The output matrix with shape (M, N). + torch.Tensor: Output with shape (M, N) in higher precision format. 
""" _LOGGER.info( diff --git a/aiter/ops/triton/gemm_a8w8_blockscale.py b/aiter/ops/triton/gemm_a8w8_blockscale.py index 6ec327059d..6609c26890 100644 --- a/aiter/ops/triton/gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/gemm_a8w8_blockscale.py @@ -27,19 +27,23 @@ def gemm_a8w8_blockscale( config: Optional[dict] = None, ): """ - Computes the 8 bit matmul Y = X x WT using the block-scale quantization approach. - - Key parameters: - - X: Matrix X with shape (M, K). - - W: Matrix W with shape (N, K). - - X_scale: Scale tensor for X with shape (M, *scale_k). - - W_scale: Scale tensor for W with shape (**scale_n, *scale_k). + Computes 8 bit matrix multiplication Y = X @ W^T using block-wise quantization scales. + Each block along K and N dimensions has independent scale factors for fine-grained quantization. + + Args: + x (torch.Tensor): INT8 input matrix with shape (M, K). + w (torch.Tensor): INT8 weight matrix with shape (N, K), internally transposed. + x_scale (torch.Tensor): Block-wise scale for x with shape (M, scale_k). + scale_k = ceil(K / scale_block_size_k). + w_scale (torch.Tensor): Block-wise scale for w with shape (scale_n, scale_k). + scale_n = ceil(N / scale_block_size_n). + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT). Returns: - - Y: The output matrix with shape (M, N). - - *scale_k = (K + scale_block_size_k - 1) // scale_block_size_k -> ceil_div(K, scale_block_size_k) - **scale_n = (N + scale_block_size_n - 1) // scale_block_size_n -> ceil_div(N, scale_block_size_n) + torch.Tensor: Output with shape (M, N). 
""" _LOGGER.info( f"GEMM_A8W8_BLOCKSCALE: x={tuple(x.shape)} w={tuple(w.shape)} x_scale={tuple(x_scale.shape)} w_scale={tuple(w_scale.shape)}" diff --git a/aiter/ops/triton/gemm_a8w8_per_token_scale.py b/aiter/ops/triton/gemm_a8w8_per_token_scale.py index ec3c45a37f..e8032bdbeb 100644 --- a/aiter/ops/triton/gemm_a8w8_per_token_scale.py +++ b/aiter/ops/triton/gemm_a8w8_per_token_scale.py @@ -24,17 +24,21 @@ def gemm_a8w8_per_token_scale( config=None, ): """ - Computes the 8 bit matmul Y = X x WT using the block-scale quantization approach. + Computes 8 bit matrix multiplication Y = X @ W^T using per-token quantization scales. + Each token (row) in x and each output column in w has independent scale factors. - Key parameters: - - X: Matrix X with shape (M, K). - - W: Matrix W with shape (N, K). - - X_scale: Scale tensor for X with shape (M, 1). - - W_scale: Scale tensor for W with shape (N, 1). - - Y: Output Matrix Y with shape (M, K). If this is none, then it's created by this API and returned as output + Args: + x (torch.Tensor): INT8 input matrix with shape (M, K). + w (torch.Tensor): INT8 weight matrix with shape (N, K), internally transposed. + x_scale (torch.Tensor): Per-token scale for x with shape (M, 1) or (M,). + w_scale (torch.Tensor): Per-output-channel scale for w with shape (N, 1) or (N,). + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT). Returns: - - Y: The output matrix with shape (M, N). + torch.Tensor: Output with shape (M, N). 
""" M, K = x.shape N, K = w.shape diff --git a/aiter/ops/triton/gemm_a8wfp4.py b/aiter/ops/triton/gemm_a8wfp4.py index 1870aca172..ffa4b7b6d8 100644 --- a/aiter/ops/triton/gemm_a8wfp4.py +++ b/aiter/ops/triton/gemm_a8wfp4.py @@ -34,29 +34,27 @@ def gemm_a8wfp4( config: Optional[dict] = None, ): """ - Computes the matmul Y = X @ W.T (where W.T is the logical transpose of unpacked W) - - X is in fp8 e4m3 format. - W is in packed microscale fp4 (mxfp4) format, where 2 fp4 values are packed per uint8. - x_scales are in fp32 format (one scale per row of X). - w_scales are in e8m0 format (one scale per group of 32 elements in K dimension). - - Key parameters: - - x: Matrix X with shape (M, K) in fp8 e4m3 format - - w: Matrix W with shape (N, K//2) in packed fp4 format (2 values per uint8) - - y: Pre-allocated output matrix with shape (M, N) - - x_scales: Per-row scales for X with shape (M, 1) in fp32 format - - w_scales: Per-group scales for W with shape (N, K//32) in e8m0 format - - dtype: Output data type (default: torch.bfloat16) + Computes matrix multiplication Y = X @ W^T with FP8 activations and FP4 weights. + + Args: + x (torch.Tensor): FP8 E4M3 input matrix with shape (M, K). + w (torch.Tensor): Packed FP4 weight matrix with shape (N, K//2), internally transposed. + Each uint8 contains 2 FP4 values. + y (torch.Tensor): Pre-allocated output tensor with shape (M, N). + x_scales (torch.Tensor): FP32 per-row scale for x with shape (M, 1). + w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N, K//32). + One scale per 32 elements in K dimension. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). 
+ + Note: + - The logical shape of W after unpacking would be (N, K) + - Every 32 consecutive elements in the K dimension of W share + one E8M0 scale Returns: - - y: The output matrix with shape (M, N) containing X @ W.T - - Note: - - W is stored in packed format where each uint8 contains 2 fp4 values - - The logical shape of W after unpacking would be (N, K) - - Every 32 consecutive elements in the K dimension of W share one e8m0 scale - - X uses per-row scaling (not per-group scaling) + torch.Tensor: Output with shape (M, N). """ _LOGGER.info( f"GEMM_A8FP4: x={tuple(x.shape)} w={tuple(w.shape)} x_scale={tuple(x_scales.shape)} w_scale={tuple(w_scales.shape)} " diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index ef5b97c615..a5353b9051 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -73,20 +73,22 @@ def gemm_afp4wfp4( config: Optional[dict] = None, ): """ - Computes the matmul Y = X x W - X and W are e2m1 fp4 tensors. - x_scales and w_scales are e8m0 tensors. - Every 32 elements in the K dimension share one e8m0 scale. - - - Key parameters: - - X: Matrix X with shape (M, K). - - W: Matrix W with shape (N, K). - - X_scales: Matrix with shape (M, K // 32) - - W_scales: Matrix with shape (N, K // 32) + Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights. + + Args: + x (torch.Tensor): FP4 E2M1 input matrix with shape (M, K). + w (torch.Tensor): FP4 E2M1 weight matrix with shape (N, K), internally transposed. + x_scales (torch.Tensor): E8M0 per-group scale for x with shape (M, K//32). + One scale per 32 elements in K dimension. + w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N, K//32). + One scale per 32 elements in K dimension. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). 
+ config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). Returns: - - Y: The output matrix with shape (M, N). + torch.Tensor: Output with shape (M, N). """ _LOGGER.info( @@ -200,20 +202,23 @@ def gemm_afp4wfp4_preshuffled_scales( config: Optional[dict] = None, ): """ - Computes the matmul Y = X x W - X and W are e2m1 fp4 tensors. - x_scales and w_scales are e8m0 tensors. - Every 32 elements in the K dimension share one e8m0 scale. - - - Key parameters: - - X: Matrix X with shape (M, K). M >= 32 is required - - W: Matrix W with shape (N, K). - - X_scales: Matrix with shape (M // 32, K) - - W_scales: Matrix with shape (N // 32, K) + Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights using preshuffled scales. + Scales are arranged with M/N dimension grouped by 32 instead of K dimension. + + Args: + x (torch.Tensor): FP4 E2M1 input matrix with shape (M, K). M >= 32 required. + w (torch.Tensor): FP4 E2M1 weight matrix with shape (N, K), internally transposed. + x_scales (torch.Tensor): E8M0 per-group scale for x with shape (M//32, K). + Groups of 32 rows in M dimension share K scales. + w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N//32, K). + Groups of 32 rows in N dimension share K scales. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). Returns: - - Y: The output matrix with shape (M, N). + torch.Tensor: Output with shape (M, N). """ assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device" @@ -332,20 +337,25 @@ def gemm_afp4wfp4_preshuffled_weight_scales( use_aot: Optional[bool] = True, ): """ - Computes the matmul Y = X x W - X and W are e2m1 fp4 tensors. 
- x_scales and w_scales are e8m0 tensors. - Every 32 elements in the K dimension share one e8m0 scale. - - - Key parameters: - - X: Matrix X with shape (M, K). - - W: Matrix W with shape (N, K). - - X_scales: Matrix with shape (M // 32, K) - - W_scales: Matrix with shape (N // 32, K) + Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights using preshuffled weight scales. + Weight matrix and scales are stored in optimized layout for improved performance. + + Args: + x (torch.Tensor): FP4 E2M1 input matrix with shape (M, K). + w (torch.Tensor): FP4 E2M1 weight matrix with shape (N//16, K*16), internally transposed. + Preshuffled layout: logical shape after unpacking is (N, K). + x_scales (torch.Tensor): E8M0 per-group scale for x with shape (M//32, K) if M >= 32, + or (M, K//32) if M < 32. + w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N//32, K). + Groups of 32 rows in N dimension share K scales. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). + use_aot (Optional[bool]): Enable ahead-of-time compilation metadata. Returns: - - Y: The output matrix with shape (M, N). + torch.Tensor: Output with shape (M, N). """ assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device" diff --git a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py index 933b2c3768..94369cc2c8 100644 --- a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py @@ -25,19 +25,23 @@ def gemm_afp4wfp4_pre_quant( config: Optional[dict] = None, ): """ - Computes the matmul Y = X x W - W is an e2m1 fp4 tensor and w_scales is an e8m0 tensor. - Every 32 elements in the K dimension share one e8m0 scale. 
- X gets quantized to the microscale fp4 (mxfp4) format before the GEMM. + Computes matrix multiplication Y = X @ W^T with on-the-fly FP4 quantization of activations. + X is quantized to MXFP4 during computation, W is pre-quantized FP4. Uses atomic operations for split-K reduction. - - Key parameters: - - X: Matrix X with shape (M, K). - - W: Matrix W with shape (N, K). - - W_scales: Matrix with shape (N, K // 32) + Args: + x (torch.Tensor): Higher precision input matrix with shape (M, K) (BF16 or FP16). + Quantized to FP4 E2M1 on-the-fly during GEMM. + w (torch.Tensor): FP4 E2M1 weight matrix with shape (N, K), internally transposed. + w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N, K//32). + One scale per 32 elements in K dimension. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + Must be zero-initialized for atomic operations. + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT). Returns: - - Y: The output matrix with shape (M, N). + torch.Tensor: Output with shape (M, N). 
""" _LOGGER.info( From 6b2faf38632441fe29aceecb3d0024f5c63ec153 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Wed, 5 Nov 2025 16:42:50 +0000 Subject: [PATCH 10/11] fix formatting --- aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py index d9f6c84ac5..514f00cab6 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py @@ -67,6 +67,7 @@ ], ) + @triton.heuristics( { "EVEN_K": lambda args: (args["K"] % (args["BLOCK_SIZE_K"] // 2) == 0) From 010c21d3e5eaa3414e3b8d71431c3e140d9eca02 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Wed, 5 Nov 2025 19:08:17 +0000 Subject: [PATCH 11/11] Revert composable_kernel submodule to match main --- 3rdparty/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 32773fe5cb..7c6430eca0 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 32773fe5cb176efd2fcbb361f183164fc6525d8a +Subproject commit 7c6430eca04e62454217630ae2a0bbd70ff50a00