From d0c445da26f1f7dc3cf6b399ba6165d3cbfa8a8a Mon Sep 17 00:00:00 2001 From: "Lin, Soga" Date: Wed, 21 May 2025 06:01:27 +0000 Subject: [PATCH 01/15] Add gfx950 support in asm_moe --- aiter/fused_moe_bf16_asm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aiter/fused_moe_bf16_asm.py b/aiter/fused_moe_bf16_asm.py index fd310590af..c381221327 100755 --- a/aiter/fused_moe_bf16_asm.py +++ b/aiter/fused_moe_bf16_asm.py @@ -140,8 +140,8 @@ def asm_moe( 128, ), "asm_moe for block_scale only support (128, 128)" assert ( - w1.dtype == torch.float8_e4m3fnuz - ), "asm_moe for block_scale only support float8_e4m3fnuz weight" + w1.dtype == dtypes.fp8 + ), "asm_moe for block_scale only support float8_e4m3fnuz weight on gfx942 and float8_e4m3fn on gfx950" assert ( w2.shape[2] * 2 == w1.shape[1] ), "aiter moe for block_scale only support g1u1" @@ -150,7 +150,7 @@ def asm_moe( a1_q, a1_scale = pertoken_quant( hidden_states.view(-1, model_dim // scale_blk_k, scale_blk_k), - quant_dtype=torch.float8_e4m3fnuz, + quant_dtype=dtypes.fp8, ) a1_q = a1_q.view(-1, model_dim) a1_scale = a1_scale.squeeze(-1).t().contiguous() @@ -429,7 +429,7 @@ def ck_moe_2stages( ): quant_func = get_hip_quant(quant_type) - q_dtype_a = w1.dtype if w1.dtype != torch.uint32 else torch.float8_e4m3fnuz + q_dtype_a = w1.dtype if w1.dtype != torch.uint32 else dtypes.fp8 # quant_func = get_torch_quant(quant_type) E, model_dim, inter_dim = w2.shape From c9ecccc6dfaad0973be3e78fd410458e1949d7fe Mon Sep 17 00:00:00 2001 From: "Lin, Soga" Date: Fri, 14 Nov 2025 12:47:18 +0000 Subject: [PATCH 02/15] fix triton compile issue --- .../triton/batched_gemm_afp4wfp4_pre_quant.py | 27 ++++++++++++---- aiter/ops/triton/gemm_a16w16_atomic.py | 23 +++++++++++--- aiter/ops/triton/gemm_afp4wfp4.py | 31 ++++++++++++++----- .../triton/gemm_afp4wfp4_pre_quant_atomic.py | 26 ++++++++++++---- 4 files changed, 83 insertions(+), 24 deletions(-) diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index 8679344856..e316ab6d50 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -3,6 +3,7 @@ from typing import Optional import torch +from torch import Tensor import triton import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton._triton_kernels.batched_gemm_afp4wfp4_pre_quant import ( @@ -11,6 +12,7 @@ _get_config, ) from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() @@ -23,14 +25,25 @@ def set_use_gemm_splitk_bf16(value: bool): _USE_GEMM_SPLITK_BF16 = value +def batched_gemm_afp4wfp4_pre_quant_fake_tensor( + x: Tensor, + w: Tensor, + x_scales: Tensor, + w_scales: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, +) -> Tensor: + return y + +@torch_compile_guard(gen_fake=batched_gemm_afp4wfp4_pre_quant_fake_tensor) def batched_gemm_afp4wfp4_pre_quant( - x, - w, - w_scales, - dtype: Optional[float] = torch.bfloat16, + x: Tensor, + w: Tensor, + w_scales: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, -): + #config: Optional[dict] = None, +) -> Tensor: """ Computes batched FP4 matrix multiplication Y[i] = X[i] @ W[i]^T with active activation quantization. X is quantized to MXFP4 during computation, W is pre-quantized FP4. @@ -61,6 +74,8 @@ def batched_gemm_afp4wfp4_pre_quant( assert Bx == Bw == By Batch = Bx + config = {} + config = None if config is None: config = _get_config(M, N, K) diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index 78026c80f0..c4cea27964 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -3,6 +3,7 @@ from typing import Optional import torch +from torch import Tensor import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info @@ -11,17 +12,27 @@ _get_config, ) from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() +def gemm_a16w16_atomic_fake_tensor( + x: Tensor, + w: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, +) -> Tensor: + return y + +@torch_compile_guard(gen_fake=gemm_a16w16_atomic_fake_tensor) def gemm_a16w16_atomic( - x, - w, - dtype: Optional[float] = torch.bfloat16, + x: Tensor, + w: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, -): + #config: Optional[dict] = None, +) -> Tensor: """ Computes 16 bit matrix multiplication Y = X @ W^T using atomic operations for split-K reduction. @@ -48,6 +59,8 @@ def gemm_a16w16_atomic( M, K = x.shape K, N = w.shape + config = {} + config = None if config is None: config = _get_config(M, N, K) # For compatability reasons, these keys may not exist in the config diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index a5353b9051..8b4ff38a92 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -3,9 +3,11 @@ from typing import Optional import torch +from torch import Tensor import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info +from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.gemm_afp4wfp4 import ( _gemm_afp4_wfp4_kernel, @@ -62,16 +64,29 @@ def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): return SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT +def gemm_afp4wfp4_fake_tensor( + x: Tensor, + w: Tensor, + x_scales: Tensor, + w_scales: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, +) -> Tensor: + M, K = x.shape + N, K = w.shape + out = torch.empty((M, N), dtype=dtype, device=x.device) + return out +@torch_compile_guard(gen_fake=gemm_afp4wfp4_fake_tensor) def gemm_afp4wfp4( - x, - w, - x_scales, - w_scales, - dtype: Optional[float] = torch.bfloat16, + x: Tensor, + w: Tensor, + x_scales: Tensor, + w_scales: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, -): + #config: Optional[dict] = None, +) -> Tensor: """ Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights. @@ -106,6 +121,8 @@ def gemm_afp4wfp4( if y is None: y = torch.empty((M, N), dtype=dtype, device=x.device) + config = {} + config = None if config is None: config = _get_config(M, N, K) diff --git a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py index 94369cc2c8..0cd77d7025 100644 --- a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py @@ -3,6 +3,7 @@ from typing import Optional import torch +from torch import Tensor import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info @@ -12,18 +13,29 @@ _gemm_afp4_wfp4_pre_quant_kernel, _get_config, ) +from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() +def gemm_afp4wfp4_pre_quant_fake_tensor( + x: Tensor, + w: Tensor, + w_scales: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, +) -> Tensor: + return y + +@torch_compile_guard(gen_fake=gemm_afp4wfp4_pre_quant_fake_tensor) def gemm_afp4wfp4_pre_quant( - x, - w, - w_scales, - dtype: Optional[float] = torch.bfloat16, + x: Tensor, + w: Tensor, + w_scales: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, -): +# config: Optional[dict] = None, +) -> Tensor: """ Computes matrix multiplication Y = X @ W^T with on-the-fly FP4 quantization of activations. X is quantized to MXFP4 during computation, W is pre-quantized FP4. Uses atomic operations for split-K reduction. @@ -59,6 +71,8 @@ def gemm_afp4wfp4_pre_quant( if y is None: y = torch.zeros((M, N), dtype=dtype, device=x.device) + config = {} + config = None if config is None: config = _get_config(M, N, K) From 95924b3b8fefacdccbfb2b551b16508063722eaa Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Thu, 20 Nov 2025 09:27:56 +0000 Subject: [PATCH 03/15] keep config --- .../triton/batched_gemm_afp4wfp4_pre_quant.py | 67 ++++++++++++------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index e316ab6d50..f462233013 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -from typing import Optional +from typing import Optional, Dict, Any import torch from torch import Tensor import triton @@ -24,44 +24,34 @@ def set_use_gemm_splitk_bf16(value: bool): global _USE_GEMM_SPLITK_BF16 _USE_GEMM_SPLITK_BF16 = value +def serialize_dict(d: Dict[str, Any]) -> str: + items_list = list(d.items()) + sorted_items = sorted(items_list) + return json.dumps(sorted_items) + +def deserialize_string(s: str) -> Dict[str, Any]: + items_list = json.loads(s) + return dict(items_list) def batched_gemm_afp4wfp4_pre_quant_fake_tensor( x: Tensor, w: Tensor, - x_scales: Tensor, w_scales: Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, + config: Optional[str] = None, ) -> Tensor: return y @torch_compile_guard(gen_fake=batched_gemm_afp4wfp4_pre_quant_fake_tensor) -def batched_gemm_afp4wfp4_pre_quant( +def batched_gemm_afp4wfp4_pre_quant_( x: Tensor, w: Tensor, w_scales: Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - #config: Optional[dict] = None, + config: Optional[str] = None, ) -> Tensor: - """ - Computes batched FP4 matrix multiplication Y[i] = X[i] @ W[i]^T with active activation quantization. - X is quantized to MXFP4 during computation, W is pre-quantized FP4. - - Args: - x (torch.Tensor): Higher precision input batch with shape (B, M, K) (BF16 or FP16). - Quantized to MXFP4 on-the-fly during GEMM. - w (torch.Tensor): FP4 E2M1 weight batch with shape (B, N, K), internally transposed. - w_scales (torch.Tensor): E8M0 per-group scale for w with shape (B, N, K//32). - One scale per 32 elements in K dimension. - dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). - y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N). - config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, - BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). - - Returns: - torch.Tensor: Output batch with shape (B, M, N). - """ _LOGGER.info( f"BATCHED_GEMM_AFP4WFP_PREQUANT: x={tuple(x.shape)} w={tuple(w.shape)} w_scale={tuple(w.shape)}" ) @@ -74,10 +64,10 @@ def batched_gemm_afp4wfp4_pre_quant( assert Bx == Bw == By Batch = Bx - config = {} - config = None if config is None: config = _get_config(M, N, K) + else: + config = deserialize_string(config) if config["NUM_KSPLIT"] > 1: SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( @@ -169,3 +159,32 @@ def batched_gemm_afp4wfp4_pre_quant( config["NUM_KSPLIT"], ) return y + +def batched_gemm_afp4wfp4_pre_quant( + x: Tensor, + w: Tensor, + w_scales: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, +) -> Tensor: + """ + Computes batched FP4 matrix multiplication Y[i] = X[i] @ W[i]^T with active activation quantization. + X is quantized to MXFP4 during computation, W is pre-quantized FP4. + + Args: + x (torch.Tensor): Higher precision input batch with shape (B, M, K) (BF16 or FP16). + Quantized to MXFP4 on-the-fly during GEMM. + w (torch.Tensor): FP4 E2M1 weight batch with shape (B, N, K), internally transposed. + w_scales (torch.Tensor): E8M0 per-group scale for w with shape (B, N, K//32). + One scale per 32 elements in K dimension. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). + + Returns: + torch.Tensor: Output batch with shape (B, M, N). + """ + config_hashable = serialize_dict(config) if config else None + return batched_gemm_afp4wfp4_pre_quant_(x, w, w_scales, dtype, y, config_hashable) From 109e297f25e86aa3f1d49eb4fdc357c3adf1e1b7 Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Fri, 21 Nov 2025 03:36:09 +0000 Subject: [PATCH 04/15] add for other 3 functions --- .../triton/batched_gemm_afp4wfp4_pre_quant.py | 11 +--- aiter/ops/triton/gemm_a16w16_atomic.py | 54 ++++++++++------- aiter/ops/triton/gemm_afp4wfp4.py | 59 +++++++++++------- .../triton/gemm_afp4wfp4_pre_quant_atomic.py | 60 +++++++++++-------- aiter/ops/triton/utils/common_utils.py | 11 +++- 5 files changed, 117 insertions(+), 78 deletions(-) diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index f462233013..3ba819c09a 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -from typing import Optional, Dict, Any +from typing import Optional import torch from torch import Tensor import triton import aiter.ops.triton.utils._triton.arch_info as arch_info +from .utils.common_utils import serialize_dict, deserialize_string from aiter.ops.triton._triton_kernels.batched_gemm_afp4wfp4_pre_quant import ( _batched_gemm_afp4_wfp4_pre_quant_reduce_kernel, _batched_gemm_afp4_wfp4_pre_quant_kernel, @@ -23,14 +24,6 @@ def set_use_gemm_splitk_bf16(value: bool): global _USE_GEMM_SPLITK_BF16 _USE_GEMM_SPLITK_BF16 = value - -def serialize_dict(d: Dict[str, Any]) -> str: - items_list = list(d.items()) - sorted_items = sorted(items_list) - return json.dumps(sorted_items) - -def deserialize_string(s: str) -> Dict[str, Any]: - items_list = json.loads(s) return dict(items_list) def batched_gemm_afp4wfp4_pre_quant_fake_tensor( diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index c4cea27964..9465cdbb8c 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -7,6 +7,7 @@ import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info +from .utils.common_utils import serialize_dict, deserialize_string from aiter.ops.triton._triton_kernels.gemm_a16w16_atomic import ( _gemm_a16_w16_atomic_kernel, _get_config, @@ -22,34 +23,18 @@ def gemm_a16w16_atomic_fake_tensor( w: Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, + config: Optional[str] = None, ) -> Tensor: return y @torch_compile_guard(gen_fake=gemm_a16w16_atomic_fake_tensor) -def gemm_a16w16_atomic( +def gemm_a16w16_atomic_( x: Tensor, w: Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - #config: Optional[dict] = None, + config: Optional[str] = None, ) -> Tensor: - """ - Computes 16 bit matrix multiplication Y = X @ W^T using atomic operations for split-K reduction. - - Args: - x (torch.Tensor): Input matrix with shape (M, K). - w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. - dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). - Note: BF16 atomic aggregation may have slight precision loss. - y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). - Must be zero-initialized for split-K (NUM_KSPLIT > 1). - config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, - BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, cache_modifier). - - Returns: - torch.Tensor: Output with shape (M, N). - """ - _LOGGER.info( f"GEMM_A16W16_ATOMIC: x.shape={tuple(x.shape)}, w.shape={tuple(w.shape)} " ) @@ -59,10 +44,11 @@ def gemm_a16w16_atomic( M, K = x.shape K, N = w.shape - config = {} - config = None if config is None: config = _get_config(M, N, K) + else: + config = deserialize_string(config) + # For compatability reasons, these keys may not exist in the config # TODO: This needs to be embedded in the configs later if "NUM_KSPLIT" not in config: @@ -102,3 +88,29 @@ def gemm_a16w16_atomic( ) return y + +def gemm_a16w16_atomic( + x: Tensor, + w: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, +) -> Tensor: + """ + Computes 16 bit matrix multiplication Y = X @ W^T using atomic operations for split-K reduction. + + Args: + x (torch.Tensor): Input matrix with shape (M, K). + w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + Note: BF16 atomic aggregation may have slight precision loss. + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + Must be zero-initialized for split-K (NUM_KSPLIT > 1). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, cache_modifier). + + Returns: + torch.Tensor: Output with shape (M, N). + """ + config_hashable = serialize_dict(config) if config else None + return gemm_a16w16_atomic_(x, w, dtype, y, config_hashable) diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index 8b4ff38a92..3021701169 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -7,6 +7,7 @@ import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info +from .utils.common_utils import serialize_dict, deserialize_string from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.gemm_afp4wfp4 import ( @@ -71,6 +72,7 @@ def gemm_afp4wfp4_fake_tensor( w_scales: Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, + config: Optional[str] = None, ) -> Tensor: M, K = x.shape N, K = w.shape @@ -78,34 +80,15 @@ def gemm_afp4wfp4_fake_tensor( return out @torch_compile_guard(gen_fake=gemm_afp4wfp4_fake_tensor) -def gemm_afp4wfp4( +def gemm_afp4wfp4_( x: Tensor, w: Tensor, x_scales: Tensor, w_scales: Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - #config: Optional[dict] = None, + config: Optional[str] = None, ) -> Tensor: - """ - Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights. - - Args: - x (torch.Tensor): FP4 E2M1 input matrix with shape (M, K). - w (torch.Tensor): FP4 E2M1 weight matrix with shape (N, K), internally transposed. - x_scales (torch.Tensor): E8M0 per-group scale for x with shape (M, K//32). - One scale per 32 elements in K dimension. - w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N, K//32). - One scale per 32 elements in K dimension. - dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). - y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). - config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, - BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). - - Returns: - torch.Tensor: Output with shape (M, N). - """ - _LOGGER.info( f"GEMM_AFPWFP4: x.shape={tuple(x.shape)} w.shape={tuple(w.shape)} x_scale={tuple(x_scales.shape)} w_scale={tuple(w_scales.shape)} " ) @@ -121,10 +104,10 @@ def gemm_afp4wfp4( if y is None: y = torch.empty((M, N), dtype=dtype, device=x.device) - config = {} - config = None if config is None: config = _get_config(M, N, K) + else: + config = deserialize_string(config) if config["NUM_KSPLIT"] > 1: SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( @@ -515,3 +498,33 @@ def gemm_afp4wfp4_preshuffled_weight_scales( ) return y + +def gemm_afp4wfp4( + x: Tensor, + w: Tensor, + x_scales: Tensor, + w_scales: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, +) -> Tensor: + """ + Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights. + + Args: + x (torch.Tensor): FP4 E2M1 input matrix with shape (M, K). + w (torch.Tensor): FP4 E2M1 weight matrix with shape (N, K), internally transposed. + x_scales (torch.Tensor): E8M0 per-group scale for x with shape (M, K//32). + One scale per 32 elements in K dimension. + w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N, K//32). + One scale per 32 elements in K dimension. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). + + Returns: + torch.Tensor: Output with shape (M, N). + """ + config_hashable = serialize_dict(config) if config else None + return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable) diff --git a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py index 0cd77d7025..ad73a9057e 100644 --- a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py @@ -7,6 +7,7 @@ import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info +from .utils.common_utils import serialize_dict, deserialize_string from aiter.ops.triton.quant import _mxfp4_quant_op from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.gemm_afp4wfp4_pre_quant_atomic import ( @@ -24,38 +25,19 @@ def gemm_afp4wfp4_pre_quant_fake_tensor( w_scales: Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, + config: Optional[str] = None, ) -> Tensor: return y @torch_compile_guard(gen_fake=gemm_afp4wfp4_pre_quant_fake_tensor) -def gemm_afp4wfp4_pre_quant( +def gemm_afp4wfp4_pre_quant_( x: Tensor, w: Tensor, w_scales: Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, -# config: Optional[dict] = None, + config: Optional[str] = None, ) -> Tensor: - """ - Computes matrix multiplication Y = X @ W^T with on-the-fly FP4 quantization of activations. - X is quantized to MXFP4 during computation, W is pre-quantized FP4. Uses atomic operations for split-K reduction. - - Args: - x (torch.Tensor): Higher precision input matrix with shape (M, K) (BF16 or FP16). - Quantized to FP4 E2M1 on-the-fly during GEMM. - w (torch.Tensor): FP4 E2M1 weight matrix with shape (N, K), internally transposed. - w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N, K//32). - One scale per 32 elements in K dimension. - dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). - y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). - Must be zero-initialized for atomic operations. - config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, - BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT). - - Returns: - torch.Tensor: Output with shape (M, N). - """ - _LOGGER.info( f"GEMM_AFP4WFP4_PRE_QUANT_ATOMIC: x={tuple(x.shape)} w={tuple(w.shape)} w_scale={tuple(w_scales.shape)} " ) @@ -71,10 +53,10 @@ def gemm_afp4wfp4_pre_quant( if y is None: y = torch.zeros((M, N), dtype=dtype, device=x.device) - config = {} - config = None if config is None: config = _get_config(M, N, K) + else: + config = deserialize_string(config) grid = lambda META: ( # noqa: E731 ( @@ -104,3 +86,33 @@ def gemm_afp4wfp4_pre_quant( ) return y + +def gemm_afp4wfp4_pre_quant( + x: Tensor, + w: Tensor, + w_scales: Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, +) -> Tensor: + """ + Computes matrix multiplication Y = X @ W^T with on-the-fly FP4 quantization of activations. + X is quantized to MXFP4 during computation, W is pre-quantized FP4. Uses atomic operations for split-K reduction. + + Args: + x (torch.Tensor): Higher precision input matrix with shape (M, K) (BF16 or FP16). + Quantized to FP4 E2M1 on-the-fly during GEMM. + w (torch.Tensor): FP4 E2M1 weight matrix with shape (N, K), internally transposed. + w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N, K//32). + One scale per 32 elements in K dimension. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + Must be zero-initialized for atomic operations. + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT). + + Returns: + torch.Tensor: Output with shape (M, N). + """ + config_hashable = serialize_dict(config) if config else None + return gemm_afp4wfp4_pre_quant_(x, w, w_scales, dtype, y, config_hashable) diff --git a/aiter/ops/triton/utils/common_utils.py b/aiter/ops/triton/utils/common_utils.py index 2da76efe38..4880c59d04 100644 --- a/aiter/ops/triton/utils/common_utils.py +++ b/aiter/ops/triton/utils/common_utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -from typing import List +from typing import List, Dict, Any import torch import triton @@ -34,3 +34,12 @@ def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor: if x.stride(-1) == 1: return x return x.contiguous() + +def serialize_dict(d: Dict[str, Any]) -> str: + items_list = list(d.items()) + sorted_items = sorted(items_list) + return json.dumps(sorted_items) + +def deserialize_string(s: str) -> Dict[str, Any]: + items_list = json.loads(s) + return dict(items_list) From c7bda54eb6f88981f7199c48905ce5a9520fd9bc Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Fri, 21 Nov 2025 03:41:44 +0000 Subject: [PATCH 05/15] Revert "Add gfx950 support in asm_moe" This reverts commit d0c445da26f1f7dc3cf6b399ba6165d3cbfa8a8a. --- aiter/fused_moe_bf16_asm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aiter/fused_moe_bf16_asm.py b/aiter/fused_moe_bf16_asm.py index 6e1ab038ac..81df5ea592 100755 --- a/aiter/fused_moe_bf16_asm.py +++ b/aiter/fused_moe_bf16_asm.py @@ -140,8 +140,8 @@ def asm_moe( 128, ), "asm_moe for block_scale only support (128, 128)" assert ( - w1.dtype == dtypes.fp8 - ), "asm_moe for block_scale only support float8_e4m3fnuz weight on gfx942 and float8_e4m3fn on gfx950" + w1.dtype == torch.float8_e4m3fnuz + ), "asm_moe for block_scale only support float8_e4m3fnuz weight" assert ( w2.shape[2] * 2 == w1.shape[1] ), "aiter moe for block_scale only support g1u1" @@ -150,7 +150,7 @@ def asm_moe( a1_q, a1_scale = pertoken_quant( hidden_states.view(-1, model_dim // scale_blk_k, scale_blk_k), - quant_dtype=dtypes.fp8, + quant_dtype=torch.float8_e4m3fnuz, ) a1_q = a1_q.view(-1, model_dim) a1_scale = a1_scale.squeeze(-1).t().contiguous() @@ -447,7 +447,7 @@ def ck_moe_2stages( ): quant_func = get_hip_quant(quant_type) - q_dtype_a = w1.dtype if w1.dtype != torch.uint32 else dtypes.fp8 + q_dtype_a = w1.dtype if w1.dtype != torch.uint32 else torch.float8_e4m3fnuz # quant_func = get_torch_quant(quant_type) E, model_dim, inter_dim = w2.shape From dbf6932eda0b84b5f274fd2a2e8d7cb7a3ea6e06 Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Fri, 21 Nov 2025 03:51:19 +0000 Subject: [PATCH 06/15] for consistency --- aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py | 2 +- aiter/ops/triton/gemm_a16w16_atomic.py | 2 +- aiter/ops/triton/gemm_afp4wfp4.py | 4 ++-- aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index 3ba819c09a..3a2b57191e 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -6,13 +6,13 @@ from torch import Tensor import triton import aiter.ops.triton.utils._triton.arch_info as arch_info -from .utils.common_utils import serialize_dict, deserialize_string from aiter.ops.triton._triton_kernels.batched_gemm_afp4wfp4_pre_quant import ( _batched_gemm_afp4_wfp4_pre_quant_reduce_kernel, _batched_gemm_afp4_wfp4_pre_quant_kernel, _get_config, ) from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_string from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index 9465cdbb8c..4d660fb9ea 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -7,12 +7,12 @@ import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info -from .utils.common_utils import serialize_dict, deserialize_string from aiter.ops.triton._triton_kernels.gemm_a16w16_atomic import ( _gemm_a16_w16_atomic_kernel, _get_config, ) from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_string from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index 3021701169..551a2dc17b 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -7,9 +7,8 @@ import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info -from .utils.common_utils import serialize_dict, deserialize_string -from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_string from aiter.ops.triton._triton_kernels.gemm_afp4wfp4 import ( _gemm_afp4_wfp4_kernel, _gemm_afp4_wfp4_kernel_preshuffled_scales, @@ -18,6 +17,7 @@ _get_config, ) from .utils.core import AITER_TRITON_CONFIGS_PATH +from aiter.jit.utils.torch_guard import torch_compile_guard import os from aiter.utility.triton.triton_metadata_redirect import AOTMetadataContext diff --git a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py index ad73a9057e..5f7001345b 100644 --- a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py @@ -7,9 +7,9 @@ import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info -from .utils.common_utils import serialize_dict, deserialize_string from aiter.ops.triton.quant import _mxfp4_quant_op from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_string from aiter.ops.triton._triton_kernels.gemm_afp4wfp4_pre_quant_atomic import ( _gemm_afp4_wfp4_pre_quant_kernel, _get_config, From a0d98b8a854f9869111f3fbf7dab96a876a89892 Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Fri, 21 Nov 2025 11:27:34 +0000 Subject: [PATCH 07/15] fix errors --- aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py | 1 - aiter/ops/triton/gemm_afp4wfp4.py | 4 ++-- aiter/ops/triton/utils/common_utils.py | 1 + 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index 3a2b57191e..da83a4da97 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -24,7 +24,6 @@ def set_use_gemm_splitk_bf16(value: bool): global _USE_GEMM_SPLITK_BF16 _USE_GEMM_SPLITK_BF16 = value - return dict(items_list) def batched_gemm_afp4wfp4_pre_quant_fake_tensor( x: Tensor, diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index 551a2dc17b..d51899014e 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -74,8 +74,8 @@ def gemm_afp4wfp4_fake_tensor( y: Optional[torch.Tensor] = None, config: Optional[str] = None, ) -> Tensor: - M, K = x.shape - N, K = w.shape + M, _ = x.shape + N, _ = w.shape out = torch.empty((M, N), dtype=dtype, device=x.device) return out diff --git a/aiter/ops/triton/utils/common_utils.py b/aiter/ops/triton/utils/common_utils.py index 4880c59d04..4a4cb9604c 100644 --- a/aiter/ops/triton/utils/common_utils.py +++ b/aiter/ops/triton/utils/common_utils.py @@ -5,6 +5,7 @@ import torch import triton +import json def prev_power_of_2(x: int) -> int: From 4a9d464e4678dfee58d2710041069900cdce70e5 Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Mon, 24 Nov 2025 08:47:55 +0000 Subject: [PATCH 08/15] fix gen_fake issues considering y is None case --- aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py | 7 +++++++ aiter/ops/triton/gemm_a16w16_atomic.py | 4 ++++ aiter/ops/triton/gemm_afp4wfp4.py | 9 +++++---- aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py | 4 ++++ 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index da83a4da97..438fd459ae 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -33,6 +33,10 @@ def batched_gemm_afp4wfp4_pre_quant_fake_tensor( y: Optional[torch.Tensor] = None, config: Optional[str] = None, ) -> Tensor: + if y is None: + Bx, M, _ = x.shape + _, N, _ = w.shape + return torch.empty(Bx, M, N), dtype=dtype, device=x.device) return y @torch_compile_guard(gen_fake=batched_gemm_afp4wfp4_pre_quant_fake_tensor) @@ -56,6 +60,9 @@ def batched_gemm_afp4wfp4_pre_quant_( assert Bx == Bw == By Batch = Bx + if y is None: + y = torch.empty((M, N), dtype=dtype, device=x.device) + if config is None: config = _get_config(M, N, K) else: diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index 4d660fb9ea..e67ea61855 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -25,6 +25,10 @@ def gemm_a16w16_atomic_fake_tensor( y: Optional[torch.Tensor] = None, config: Optional[str] = None, ) -> Tensor: + if y is None: + M, _ = x.shape + _, N = w.shape + return torch.zeros((M, N), dtype=dtype, device=x.device) return y @torch_compile_guard(gen_fake=gemm_a16w16_atomic_fake_tensor) diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index d51899014e..a41c41983a 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -74,10 +74,11 @@ def gemm_afp4wfp4_fake_tensor( y: Optional[torch.Tensor] = None, config: Optional[str] = None, ) -> Tensor: - M, _ = x.shape - N, _ = w.shape - out = torch.empty((M, N), dtype=dtype, device=x.device) - return out + if y is None: + M, _ = x.shape + N, _ = w.shape + return torch.empty((M, N), dtype=dtype, device=x.device) + return y @torch_compile_guard(gen_fake=gemm_afp4wfp4_fake_tensor) def gemm_afp4wfp4_( diff --git a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py index 5f7001345b..be65bcb0a8 100644 --- a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py @@ -27,6 +27,10 @@ def gemm_afp4wfp4_pre_quant_fake_tensor( y: Optional[torch.Tensor] = None, config: Optional[str] = None, ) -> Tensor: + if y is None: + M, _ = x.shape + N, _ = w.shape + return torch.zeros((M, N), dtype=dtype, device=x.device) return y @torch_compile_guard(gen_fake=gemm_afp4wfp4_pre_quant_fake_tensor) From 82e523f0e8e2df430434fa03710c0bf9bd800cb2 Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Mon, 24 Nov 2025 10:24:25 +0000 Subject: [PATCH 09/15] fix error --- aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index 438fd459ae..3d27f3851e 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -36,7 +36,7 @@ def batched_gemm_afp4wfp4_pre_quant_fake_tensor( if y is None: Bx, M, _ = x.shape _, N, _ = w.shape - return torch.empty(Bx, M, N), dtype=dtype, device=x.device) + return torch.empty((Bx, M, N), dtype=dtype, device=x.device) return y @torch_compile_guard(gen_fake=batched_gemm_afp4wfp4_pre_quant_fake_tensor) From ae2022a83269e541fb5c9c39a12456689c60a800 Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Mon, 15 Dec 2025 06:04:52 -0600 Subject: [PATCH 10/15] fix conflicts --- aiter/ops/triton/batched_gemm_a16wfp4.py | 33 +++- .../triton/batched_gemm_afp4wfp4_pre_quant.py | 176 +----------------- aiter/ops/triton/gemm_a16w16_atomic.py | 4 +- aiter/ops/triton/gemm_a16wfp4.py | 27 ++- aiter/ops/triton/gemm_afp4wfp4.py | 88 +++------ .../triton/gemm_afp4wfp4_pre_quant_atomic.py | 118 +----------- aiter/ops/triton/utils/common_utils.py | 3 +- 7 files changed, 98 insertions(+), 351 deletions(-) diff --git a/aiter/ops/triton/batched_gemm_a16wfp4.py b/aiter/ops/triton/batched_gemm_a16wfp4.py index a10cc66bea..d3c05b3f49 100755 --- a/aiter/ops/triton/batched_gemm_a16wfp4.py +++ b/aiter/ops/triton/batched_gemm_a16wfp4.py @@ -11,9 +11,11 @@ _get_config, ) from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import deserialize_dict from aiter.ops.triton.gemm_a16wfp4 import ( get_splitk, ) +from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() @@ -25,18 +27,35 @@ def set_use_gemm_splitk_bf16(value: bool): global _USE_GEMM_SPLITK_BF16 _USE_GEMM_SPLITK_BF16 = value +def batched_gemm_a16wfp4_fake_tensor( + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[str] = None, + transpose_bm: Optional[bool] = False, + prequant: Optional[bool] = True, + y_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if y is None: + Bx, M, _ = x.shape + _, N, _ = w.shape + return torch.empty((Bx, M, N), dtype=dtype, device=x.device) + return y +@torch_compile_guard(gen_fake=batched_gemm_a16wfp4_fake_tensor) def batched_gemm_a16wfp4( - x, - w, - w_scales, - dtype: Optional[float] = torch.bfloat16, + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, + config: Optional[str] = None, transpose_bm: Optional[bool] = False, prequant: Optional[bool] = True, y_scale: Optional[torch.Tensor] = None, -): +) -> torch.Tensor: """ Computes batched FP4 matrix multiplication Y[i] = X[i] @ W[i]^T with active activation quantization. X is quantized to MXFP4 during computation, W is pre-quantized FP4. @@ -72,6 +91,8 @@ def batched_gemm_a16wfp4( if config is None: config = _get_config(M, N, K) + else: + config = deserialize_dict(config) if y is None: if transpose_bm: diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index b30b79e17c..05d3d00ec7 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -3,18 +3,13 @@ from typing import Optional import torch -from torch import Tensor import triton import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.utils.logger import AiterTritonLogger -<<<<<<< HEAD -from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_string -from aiter.jit.utils.torch_guard import torch_compile_guard -======= +from aiter.ops.triton.utils.common_utils import serialize_dict from aiter.ops.triton.batched_gemm_a16wfp4 import ( batched_gemm_a16wfp4, ) ->>>>>>> main _LOGGER = AiterTritonLogger() @@ -26,176 +21,19 @@ def set_use_gemm_splitk_bf16(value: bool): global _USE_GEMM_SPLITK_BF16 _USE_GEMM_SPLITK_BF16 = value -def batched_gemm_afp4wfp4_pre_quant_fake_tensor( - x: Tensor, - w: Tensor, - w_scales: Tensor, - dtype: Optional[torch.dtype] = torch.bfloat16, - y: Optional[torch.Tensor] = None, -<<<<<<< HEAD - config: Optional[str] = None, -) -> Tensor: - if y is None: - Bx, M, _ = x.shape - _, N, _ = w.shape - return torch.empty((Bx, M, N), dtype=dtype, device=x.device) - return y - -@torch_compile_guard(gen_fake=batched_gemm_afp4wfp4_pre_quant_fake_tensor) -def batched_gemm_afp4wfp4_pre_quant_( - x: Tensor, - w: Tensor, - w_scales: Tensor, - dtype: Optional[torch.dtype] = torch.bfloat16, +def batched_gemm_afp4wfp4_pre_quant( + x, + w, + w_scales, + dtype: Optional[float] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[str] = None, -) -> Tensor: -======= config: Optional[dict] = None, ): ->>>>>>> main _LOGGER.info( "batched_gemm_afp4wfp4_pre_quant will be deprecated in future AITER release, please switch to batched_gemm_a16wfp4" ) -<<<<<<< HEAD - - assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device" - - Bx, M, K = x.shape - Bw, N, K = w.shape - By, _, _ = y.shape - assert Bx == Bw == By - Batch = Bx - - if y is None: - y = torch.empty((M, N), dtype=dtype, device=x.device) - - if config is None: - config = _get_config(M, N, K) - else: - config = deserialize_string(config) - - if config["NUM_KSPLIT"] > 1: - SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( - K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"] - ) - - config["SPLITK_BLOCK_SIZE"] = SPLITK_BLOCK_SIZE - config["BLOCK_SIZE_K"] = BLOCK_SIZE_K - config["NUM_KSPLIT"] = NUM_KSPLIT - - if _USE_GEMM_SPLITK_BF16: - y_pp = torch.empty( - (Batch, config["NUM_KSPLIT"], M, N), dtype=y.dtype, device=y.device - ) - else: - y_pp = torch.empty( - (Batch, config["NUM_KSPLIT"], M, N), - dtype=torch.float32, - device=y.device, - ) - else: - config["SPLITK_BLOCK_SIZE"] = 2 * K - y_pp = None - - if config["BLOCK_SIZE_K"] >= 2 * K: - config["BLOCK_SIZE_K"] = triton.next_power_of_2(2 * K) - config["SPLITK_BLOCK_SIZE"] = 2 * K - - grid = lambda META: ( # noqa: E731 - Batch, - ( - META["NUM_KSPLIT"] - * triton.cdiv(M, META["BLOCK_SIZE_M"]) - * triton.cdiv(N, META["BLOCK_SIZE_N"]) - ), - ) - _batched_gemm_afp4_wfp4_pre_quant_kernel[grid]( - x, - w, - y if config["NUM_KSPLIT"] == 1 else y_pp, - w_scales, - M, - N, - K, - x.stride(0), - x.stride(1), - x.stride(2), - w.stride(0), - w.stride(1), - w.stride(2), - y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), - 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), - y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), - y.stride(2) if config["NUM_KSPLIT"] == 1 else y_pp.stride(3), - w_scales.stride(0), - w_scales.stride(1), - w_scales.stride(2), - **config, - ) - - if config["NUM_KSPLIT"] > 1: - REDUCE_BLOCK_SIZE_M = 16 - # TODO: Need to debug - REDUCE_BLOCK_SIZE_N=128 with fp32 partials fails - # NOTE: REDUCE_BLOCK_SIZE_N=16 gives best perf with fp32 partials and - # REDUCE_BLOCK_SIZE_N=128 gives best perf with bf16 partials - REDUCE_BLOCK_SIZE_N = 128 if _USE_GEMM_SPLITK_BF16 else 64 - ACTUAL_KSPLIT = triton.cdiv(K, (config["SPLITK_BLOCK_SIZE"] // 2)) - - grid_reduce = ( - Batch, - triton.cdiv(M, REDUCE_BLOCK_SIZE_M), - triton.cdiv(N, REDUCE_BLOCK_SIZE_N), - ) - _batched_gemm_afp4_wfp4_pre_quant_reduce_kernel[grid_reduce]( - y_pp, - y, - M, - N, - y_pp.stride(0), - y_pp.stride(1), - y_pp.stride(2), - y_pp.stride(3), - y.stride(0), - y.stride(1), - y.stride(2), - REDUCE_BLOCK_SIZE_M, - REDUCE_BLOCK_SIZE_N, - ACTUAL_KSPLIT, - config["NUM_KSPLIT"], - ) - return y - -def batched_gemm_afp4wfp4_pre_quant( - x: Tensor, - w: Tensor, - w_scales: Tensor, - dtype: Optional[torch.dtype] = torch.bfloat16, - y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, -) -> Tensor: - """ - Computes batched FP4 matrix multiplication Y[i] = X[i] @ W[i]^T with active activation quantization. - X is quantized to MXFP4 during computation, W is pre-quantized FP4. - - Args: - x (torch.Tensor): Higher precision input batch with shape (B, M, K) (BF16 or FP16). - Quantized to MXFP4 on-the-fly during GEMM. - w (torch.Tensor): FP4 E2M1 weight batch with shape (B, N, K), internally transposed. - w_scales (torch.Tensor): E8M0 per-group scale for w with shape (B, N, K//32). - One scale per 32 elements in K dimension. - dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). - y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N). - config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, - BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). - Returns: - torch.Tensor: Output batch with shape (B, M, N). - """ config_hashable = serialize_dict(config) if config else None - return batched_gemm_afp4wfp4_pre_quant_(x, w, w_scales, dtype, y, config_hashable) -======= return batched_gemm_a16wfp4( - x, w, w_scales, dtype, y, config, transpose_bm=False, prequant=True + x, w, w_scales, dtype, y, config_hashable, transpose_bm=False, prequant=True ) ->>>>>>> main diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index e67ea61855..3e0a487add 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -12,7 +12,7 @@ _get_config, ) from aiter.ops.triton.utils.logger import AiterTritonLogger -from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_string +from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_dict from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() @@ -51,7 +51,7 @@ def gemm_a16w16_atomic_( if config is None: config = _get_config(M, N, K) else: - config = deserialize_string(config) + config = deserialize_dict(config) # For compatability reasons, these keys may not exist in the config # TODO: This needs to be embedded in the configs later diff --git a/aiter/ops/triton/gemm_a16wfp4.py b/aiter/ops/triton/gemm_a16wfp4.py index 40744fba68..b3aa46c4ba 100644 --- a/aiter/ops/triton/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm_a16wfp4.py @@ -8,6 +8,7 @@ import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.quant import _mxfp4_quant_op from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import deserialize_dict from aiter.ops.triton._triton_kernels.gemm_a16wfp4 import ( _gemm_a16wfp4_kernel, _get_config, @@ -18,19 +19,35 @@ from aiter.ops.triton.gemm_afp4wfp4 import ( get_splitk, ) +from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() +def gemm_a16wfp4_fake_tensor( + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + atomic_add: bool = False, + dtype: Optional[float] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[str] = None, +) -> torch.Tensor: + if y is None: + M, _ = x.shape + N, _ = w.shape + return torch.zeros((M, N), dtype=dtype, device=x.device) + return y +@torch_compile_guard(gen_fake=gemm_a16wfp4_fake_tensor) def gemm_a16wfp4( - x, - w, - w_scales, + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, atomic_add: bool = False, dtype: Optional[float] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, + config: Optional[str] = None, ): """ Computes the matmul Y = X x W @@ -62,6 +79,8 @@ def gemm_a16wfp4( if config is None: config = _get_config(M, N, K) + else: + config = deserialize_dict(config) if y is None: if atomic_add: diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index 8d26cd2496..519bcf6dab 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -3,12 +3,11 @@ from typing import Optional import torch -from torch import Tensor import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.utils.logger import AiterTritonLogger -from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_string +from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_dict from aiter.ops.triton._triton_kernels.gemm_afp4wfp4 import ( _gemm_afp4wfp4_kernel, _gemm_afp4wfp4_kernel_preshuffle_scales, @@ -66,24 +65,32 @@ def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): return SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT def gemm_afp4wfp4_fake_tensor( - x: Tensor, - w: Tensor, - x_scales: Tensor, - w_scales: Tensor, + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, -<<<<<<< HEAD config: Optional[str] = None, -) -> Tensor: + skip_reduce: Optional[bool] = False, +) -> torch.Tensor: if y is None: M, _ = x.shape N, _ = w.shape return torch.empty((M, N), dtype=dtype, device=x.device) return y -======= - config: Optional[dict] = None, + +@torch_compile_guard(gen_fake=gemm_afp4wfp4_fake_tensor) +def gemm_afp4wfp4_( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[str] = None, skip_reduce: Optional[bool] = False, -): +) -> torch.Tensor: """ Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights. @@ -102,18 +109,6 @@ def gemm_afp4wfp4_fake_tensor( Returns: torch.Tensor: Output with shape (M, N). """ ->>>>>>> main - -@torch_compile_guard(gen_fake=gemm_afp4wfp4_fake_tensor) -def gemm_afp4wfp4_( - x: Tensor, - w: Tensor, - x_scales: Tensor, - w_scales: Tensor, - dtype: Optional[torch.dtype] = torch.bfloat16, - y: Optional[torch.Tensor] = None, - config: Optional[str] = None, -) -> Tensor: _LOGGER.info( f"GEMM_AFPWFP4: x.shape={tuple(x.shape)} w.shape={tuple(w.shape)} x_scale={tuple(x_scales.shape)} w_scale={tuple(w_scales.shape)} " ) @@ -129,7 +124,7 @@ def gemm_afp4wfp4_( if config is None: config = _get_config(M, N, K) else: - config = deserialize_string(config) + config = deserialize_dict(config) if config["NUM_KSPLIT"] > 1: SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( @@ -536,38 +531,6 @@ def gemm_afp4wfp4_preshuffle( return y -<<<<<<< HEAD -def gemm_afp4wfp4( - x: Tensor, - w: Tensor, - x_scales: Tensor, - w_scales: Tensor, - dtype: Optional[torch.dtype] = torch.bfloat16, - y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, -) -> Tensor: - """ - Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights. - - Args: - x (torch.Tensor): FP4 E2M1 input matrix with shape (M, K). - w (torch.Tensor): FP4 E2M1 weight matrix with shape (N, K), internally transposed. - x_scales (torch.Tensor): E8M0 per-group scale for x with shape (M, K//32). - One scale per 32 elements in K dimension. - w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N, K//32). - One scale per 32 elements in K dimension. - dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). - y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). - config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, - BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). - - Returns: - torch.Tensor: Output with shape (M, N). - """ - config_hashable = serialize_dict(config) if config else None - return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable) -======= - def gemm_afp4wfp4_preshuffled_weight_scales( x, w, @@ -582,4 +545,15 @@ def gemm_afp4wfp4_preshuffled_weight_scales( "gemm_afp4wfp4_preshuffled_weight_scales will be deprecated in future AITER release, please switch to gemm_afp4wfp4_preshuffle" ) return gemm_afp4wfp4_preshuffle(x, w, x_scales, w_scales, dtype, y, config, use_aot) ->>>>>>> main + +def gemm_afp4wfp4( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, +) -> torch.Tensor: + config_hashable = serialize_dict(config) if config else None + return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable) diff --git a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py index d6a08c0c03..180a104622 100644 --- a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py @@ -3,132 +3,28 @@ from typing import Optional import torch -from torch import Tensor import triton import triton.language as tl from aiter.ops.triton.utils.logger import AiterTritonLogger -<<<<<<< HEAD -from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_string -from aiter.ops.triton._triton_kernels.gemm_afp4wfp4_pre_quant_atomic import ( - _gemm_afp4_wfp4_pre_quant_kernel, - _get_config, -======= +from aiter.ops.triton.utils.common_utils import serialize_dict from aiter.ops.triton.gemm_a16wfp4 import ( gemm_a16wfp4, ->>>>>>> main ) from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() - -def gemm_afp4wfp4_pre_quant_fake_tensor( - x: Tensor, - w: Tensor, - w_scales: Tensor, - dtype: Optional[torch.dtype] = torch.bfloat16, - y: Optional[torch.Tensor] = None, -<<<<<<< HEAD - config: Optional[str] = None, -) -> Tensor: - if y is None: - M, _ = x.shape - N, _ = w.shape - return torch.zeros((M, N), dtype=dtype, device=x.device) - return y - -@torch_compile_guard(gen_fake=gemm_afp4wfp4_pre_quant_fake_tensor) -def gemm_afp4wfp4_pre_quant_( - x: Tensor, - w: Tensor, - w_scales: Tensor, - dtype: Optional[torch.dtype] = torch.bfloat16, +def gemm_afp4wfp4_pre_quant( + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[float] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[str] = None, -) -> Tensor: -======= config: Optional[dict] = None, ): ->>>>>>> main _LOGGER.info( "gemm_afp4wfp4_pre_quant will be deprecated in future AITER release, please switch to gemm_a16wfp4" ) -<<<<<<< HEAD - - assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device" - - M, K = x.shape - N, K = w.shape - - # inner kernel expects (K, N) - w = w.T - - if y is None: - y = torch.zeros((M, N), dtype=dtype, device=x.device) - - if config is None: - config = _get_config(M, N, K) - else: - config = deserialize_string(config) - - grid = lambda META: ( # noqa: E731 - ( - META["NUM_KSPLIT"] - * triton.cdiv(M, META["BLOCK_SIZE_M"]) - * triton.cdiv(N, META["BLOCK_SIZE_N"]) - ), - ) - _gemm_afp4_wfp4_pre_quant_kernel[grid]( - x, - w, - y, - w_scales, - M, - N, - K, - x.stride(0), - x.stride(1), - w.stride(0), - w.stride(1), - 0, - y.stride(0), - y.stride(1), - w_scales.stride(0), - w_scales.stride(1), - **config, - ) - - return y - -def gemm_afp4wfp4_pre_quant( - x: Tensor, - w: Tensor, - w_scales: Tensor, - dtype: Optional[torch.dtype] = torch.bfloat16, - y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, -) -> Tensor: - """ - Computes matrix multiplication Y = X @ W^T with on-the-fly FP4 quantization of activations. - X is quantized to MXFP4 during computation, W is pre-quantized FP4. Uses atomic operations for split-K reduction. - - Args: - x (torch.Tensor): Higher precision input matrix with shape (M, K) (BF16 or FP16). - Quantized to FP4 E2M1 on-the-fly during GEMM. - w (torch.Tensor): FP4 E2M1 weight matrix with shape (N, K), internally transposed. - w_scales (torch.Tensor): E8M0 per-group scale for w with shape (N, K//32). - One scale per 32 elements in K dimension. - dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). - y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). - Must be zero-initialized for atomic operations. - config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, - BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT). - Returns: - torch.Tensor: Output with shape (M, N). - """ config_hashable = serialize_dict(config) if config else None - return gemm_afp4wfp4_pre_quant_(x, w, w_scales, dtype, y, config_hashable) -======= - return gemm_a16wfp4(x, w, w_scales, True, dtype, y, config) ->>>>>>> main + return gemm_a16wfp4(x, w, w_scales, True, dtype, y, config_hashable) diff --git a/aiter/ops/triton/utils/common_utils.py b/aiter/ops/triton/utils/common_utils.py index 4a4cb9604c..6c912f5479 100644 --- a/aiter/ops/triton/utils/common_utils.py +++ b/aiter/ops/triton/utils/common_utils.py @@ -7,7 +7,6 @@ import triton import json - def prev_power_of_2(x: int) -> int: out = triton.next_power_of_2(x) return out // 2 if out > x else out @@ -41,6 +40,6 @@ def serialize_dict(d: Dict[str, Any]) -> str: sorted_items = sorted(items_list) return json.dumps(sorted_items) -def deserialize_string(s: str) -> Dict[str, Any]: +def deserialize_dict(s: str) -> Dict[str, Any]: items_list = json.loads(s) return dict(items_list) From 24c2ada71257480c0cbdfbc3ac54d7a67cf8fc65 Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Thu, 18 Dec 2025 07:20:05 +0000 Subject: [PATCH 11/15] fix error --- aiter/ops/triton/gemm_a16wfp4.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aiter/ops/triton/gemm_a16wfp4.py b/aiter/ops/triton/gemm_a16wfp4.py index b3aa46c4ba..d1390ca986 100644 --- a/aiter/ops/triton/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm_a16wfp4.py @@ -29,7 +29,7 @@ def gemm_a16wfp4_fake_tensor( w: torch.Tensor, w_scales: torch.Tensor, atomic_add: bool = False, - dtype: Optional[float] = torch.bfloat16, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[str] = None, ) -> torch.Tensor: @@ -45,10 +45,10 @@ def gemm_a16wfp4( w: torch.Tensor, w_scales: torch.Tensor, atomic_add: bool = False, - dtype: Optional[float] = torch.bfloat16, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[str] = None, -): +) -> torch.Tensor: """ Computes the matmul Y = X x W W is an e2m1 fp4 tensor and w_scales is an e8m0 tensor. From a596a628acd1771096dc3ffa507ce52223f36e32 Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Thu, 18 Dec 2025 07:55:54 +0000 Subject: [PATCH 12/15] simplify serialization --- aiter/ops/triton/batched_gemm_a16wfp4.py | 4 ++-- aiter/ops/triton/gemm_a16w16_atomic.py | 4 ++-- aiter/ops/triton/gemm_a16wfp4.py | 4 ++-- aiter/ops/triton/gemm_afp4wfp4.py | 4 ++-- aiter/ops/triton/utils/common_utils.py | 13 +++++-------- 5 files changed, 13 insertions(+), 16 deletions(-) diff --git a/aiter/ops/triton/batched_gemm_a16wfp4.py b/aiter/ops/triton/batched_gemm_a16wfp4.py index d3c05b3f49..58d5604147 100755 --- a/aiter/ops/triton/batched_gemm_a16wfp4.py +++ b/aiter/ops/triton/batched_gemm_a16wfp4.py @@ -11,7 +11,7 @@ _get_config, ) from aiter.ops.triton.utils.logger import AiterTritonLogger -from aiter.ops.triton.utils.common_utils import deserialize_dict +from aiter.ops.triton.utils.common_utils import deserialize_str from aiter.ops.triton.gemm_a16wfp4 import ( get_splitk, ) @@ -92,7 +92,7 @@ def batched_gemm_a16wfp4( if config is None: config = _get_config(M, N, K) else: - config = deserialize_dict(config) + config = deserialize_str(config) if y is None: if transpose_bm: diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index 3e0a487add..8d1e0e4b27 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -12,7 +12,7 @@ _get_config, ) from aiter.ops.triton.utils.logger import AiterTritonLogger -from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_dict +from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_str from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() @@ -51,7 +51,7 @@ def gemm_a16w16_atomic_( if config is None: config = _get_config(M, N, K) else: - config = deserialize_dict(config) + config = deserialize_str(config) # For compatability reasons, these keys may not exist in the config # TODO: This needs to be embedded in the configs later diff --git a/aiter/ops/triton/gemm_a16wfp4.py b/aiter/ops/triton/gemm_a16wfp4.py index d1390ca986..a5e1b68eff 100644 --- a/aiter/ops/triton/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm_a16wfp4.py @@ -8,7 +8,7 @@ import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.quant import _mxfp4_quant_op from aiter.ops.triton.utils.logger import AiterTritonLogger -from aiter.ops.triton.utils.common_utils import deserialize_dict +from aiter.ops.triton.utils.common_utils import deserialize_str from aiter.ops.triton._triton_kernels.gemm_a16wfp4 import ( _gemm_a16wfp4_kernel, _get_config, @@ -80,7 +80,7 @@ def gemm_a16wfp4( if config is None: config = _get_config(M, N, K) else: - config = deserialize_dict(config) + config = deserialize_str(config) if y is None: if atomic_add: diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index 519bcf6dab..f085ce7200 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -7,7 +7,7 @@ import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.utils.logger import AiterTritonLogger -from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_dict +from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_str from aiter.ops.triton._triton_kernels.gemm_afp4wfp4 import ( _gemm_afp4wfp4_kernel, _gemm_afp4wfp4_kernel_preshuffle_scales, @@ -124,7 +124,7 @@ def gemm_afp4wfp4_( if config is None: config = _get_config(M, N, K) else: - config = deserialize_dict(config) + config = deserialize_str(config) if config["NUM_KSPLIT"] > 1: SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( diff --git a/aiter/ops/triton/utils/common_utils.py b/aiter/ops/triton/utils/common_utils.py index 6c912f5479..7521b27d61 100644 --- a/aiter/ops/triton/utils/common_utils.py +++ b/aiter/ops/triton/utils/common_utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -from typing import List, Dict, Any +from typing import List import torch import triton @@ -35,11 +35,8 @@ def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor: return x return x.contiguous() -def serialize_dict(d: Dict[str, Any]) -> str: - items_list = list(d.items()) - sorted_items = sorted(items_list) - return json.dumps(sorted_items) +def serialize_dict(d: dict) -> str: + return json.dumps(d) -def deserialize_dict(s: str) -> Dict[str, Any]: - items_list = json.loads(s) - return dict(items_list) +def deserialize_str(s: str) -> dict: + return json.loads(s) From ce07338cfd1ad2914f94a7e5db5630c0a25d523d Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Thu, 18 Dec 2025 08:29:03 +0000 Subject: [PATCH 13/15] for consistency --- aiter/ops/triton/gemm_a16w16_atomic.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index 8d1e0e4b27..0fcf45877e 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -3,7 +3,6 @@ from typing import Optional import torch -from torch import Tensor import triton import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info @@ -19,12 +18,12 @@ def gemm_a16w16_atomic_fake_tensor( - x: Tensor, - w: Tensor, + x: torch.Tensor, + w: torch.Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[str] = None, -) -> Tensor: +) -> torch.Tensor: if y is None: M, _ = x.shape _, N = w.shape @@ -33,12 +32,12 @@ def gemm_a16w16_atomic_fake_tensor( @torch_compile_guard(gen_fake=gemm_a16w16_atomic_fake_tensor) def gemm_a16w16_atomic_( - x: Tensor, - w: Tensor, + x: torch.Tensor, + w: torch.Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[str] = None, -) -> Tensor: +) -> torch.Tensor: _LOGGER.info( f"GEMM_A16W16_ATOMIC: x.shape={tuple(x.shape)}, w.shape={tuple(w.shape)} " ) @@ -94,12 +93,12 @@ def gemm_a16w16_atomic_( return y def gemm_a16w16_atomic( - x: Tensor, - w: Tensor, + x: torch.Tensor, + w: torch.Tensor, dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, -) -> Tensor: +) -> torch.Tensor: """ Computes 16 bit matrix multiplication Y = X @ W^T using atomic operations for split-K reduction. From 90e00e43c6bc7b06b0e256c659de55675e81758c Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Fri, 19 Dec 2025 01:55:26 -0600 Subject: [PATCH 14/15] fix black failure --- aiter/ops/triton/batched_gemm_a16wfp4.py | 2 ++ aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py | 1 + aiter/ops/triton/gemm_a16w16_atomic.py | 2 ++ aiter/ops/triton/gemm_a16wfp4.py | 2 ++ aiter/ops/triton/gemm_afp4wfp4.py | 4 ++++ aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py | 1 + aiter/ops/triton/utils/common_utils.py | 3 +++ 7 files changed, 15 insertions(+) diff --git a/aiter/ops/triton/batched_gemm_a16wfp4.py b/aiter/ops/triton/batched_gemm_a16wfp4.py index 58d5604147..ffd8b0ba3d 100755 --- a/aiter/ops/triton/batched_gemm_a16wfp4.py +++ b/aiter/ops/triton/batched_gemm_a16wfp4.py @@ -27,6 +27,7 @@ def set_use_gemm_splitk_bf16(value: bool): global _USE_GEMM_SPLITK_BF16 _USE_GEMM_SPLITK_BF16 = value + def batched_gemm_a16wfp4_fake_tensor( x: torch.Tensor, w: torch.Tensor, @@ -44,6 +45,7 @@ def batched_gemm_a16wfp4_fake_tensor( return torch.empty((Bx, M, N), dtype=dtype, device=x.device) return y + @torch_compile_guard(gen_fake=batched_gemm_a16wfp4_fake_tensor) def batched_gemm_a16wfp4( x: torch.Tensor, diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index 05d3d00ec7..2f3718cd71 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -21,6 +21,7 @@ def set_use_gemm_splitk_bf16(value: bool): global _USE_GEMM_SPLITK_BF16 _USE_GEMM_SPLITK_BF16 = value + def batched_gemm_afp4wfp4_pre_quant( x, w, diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index 0fcf45877e..38341b3efb 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -30,6 +30,7 @@ def gemm_a16w16_atomic_fake_tensor( return torch.zeros((M, N), dtype=dtype, device=x.device) return y + @torch_compile_guard(gen_fake=gemm_a16w16_atomic_fake_tensor) def gemm_a16w16_atomic_( x: torch.Tensor, @@ -92,6 +93,7 @@ def gemm_a16w16_atomic_( return y + def gemm_a16w16_atomic( x: torch.Tensor, w: torch.Tensor, diff --git a/aiter/ops/triton/gemm_a16wfp4.py b/aiter/ops/triton/gemm_a16wfp4.py index a5e1b68eff..a9279dfcf9 100644 --- a/aiter/ops/triton/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm_a16wfp4.py @@ -24,6 +24,7 @@ _LOGGER = AiterTritonLogger() + def gemm_a16wfp4_fake_tensor( x: torch.Tensor, w: torch.Tensor, @@ -39,6 +40,7 @@ def gemm_a16wfp4_fake_tensor( return torch.zeros((M, N), dtype=dtype, device=x.device) return y + @torch_compile_guard(gen_fake=gemm_a16wfp4_fake_tensor) def gemm_a16wfp4( x: torch.Tensor, diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index f085ce7200..fdebbdd3d9 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -64,6 +64,7 @@ def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): return SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT + def gemm_afp4wfp4_fake_tensor( x: torch.Tensor, w: torch.Tensor, @@ -80,6 +81,7 @@ def gemm_afp4wfp4_fake_tensor( return torch.empty((M, N), dtype=dtype, device=x.device) return y + @torch_compile_guard(gen_fake=gemm_afp4wfp4_fake_tensor) def gemm_afp4wfp4_( x: torch.Tensor, @@ -531,6 +533,7 @@ def gemm_afp4wfp4_preshuffle( return y + def gemm_afp4wfp4_preshuffled_weight_scales( x, w, @@ -546,6 +549,7 @@ def gemm_afp4wfp4_preshuffled_weight_scales( ) return gemm_afp4wfp4_preshuffle(x, w, x_scales, w_scales, dtype, y, config, use_aot) + def gemm_afp4wfp4( x: torch.Tensor, w: torch.Tensor, diff --git a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py index 180a104622..a247212639 100644 --- a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py @@ -14,6 +14,7 @@ _LOGGER = AiterTritonLogger() + def gemm_afp4wfp4_pre_quant( x: torch.Tensor, w: torch.Tensor, diff --git a/aiter/ops/triton/utils/common_utils.py b/aiter/ops/triton/utils/common_utils.py index 7521b27d61..4729ccfdb3 100644 --- a/aiter/ops/triton/utils/common_utils.py +++ b/aiter/ops/triton/utils/common_utils.py @@ -7,6 +7,7 @@ import triton import json + def prev_power_of_2(x: int) -> int: out = triton.next_power_of_2(x) return out // 2 if out > x else out @@ -35,8 +36,10 @@ def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor: return x return x.contiguous() + def serialize_dict(d: dict) -> str: return json.dumps(d) + def deserialize_str(s: str) -> dict: return json.loads(s) From 16749d8f2369514c78fe4095b1216435777550e3 Mon Sep 17 00:00:00 2001 From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com> Date: Fri, 19 Dec 2025 02:14:51 -0600 Subject: [PATCH 15/15] fix ruff problems --- aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py | 2 -- aiter/ops/triton/gemm_a16wfp4.py | 2 -- aiter/ops/triton/gemm_afp4wfp4.py | 2 -- aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py | 3 --- 4 files changed, 9 deletions(-) diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index 2f3718cd71..92a8b30256 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -3,8 +3,6 @@ from typing import Optional import torch -import triton -import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton.utils.common_utils import serialize_dict from aiter.ops.triton.batched_gemm_a16wfp4 import ( diff --git a/aiter/ops/triton/gemm_a16wfp4.py b/aiter/ops/triton/gemm_a16wfp4.py index a9279dfcf9..2bc0983119 100644 --- a/aiter/ops/triton/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm_a16wfp4.py @@ -4,9 +4,7 @@ from typing import Optional import torch import triton -import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info -from aiter.ops.triton.quant import _mxfp4_quant_op from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton.utils.common_utils import deserialize_str from aiter.ops.triton._triton_kernels.gemm_a16wfp4 import ( diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index fdebbdd3d9..1085dd5d12 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -4,13 +4,11 @@ from typing import Optional import torch import triton -import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_str from aiter.ops.triton._triton_kernels.gemm_afp4wfp4 import ( _gemm_afp4wfp4_kernel, - _gemm_afp4wfp4_kernel_preshuffle_scales, _gemm_afp4wfp4_preshuffle_kernel, _gemm_afp4wfp4_reduce_kernel, _get_config, diff --git a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py index a247212639..2d5cbe3e32 100644 --- a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py @@ -3,14 +3,11 @@ from typing import Optional import torch -import triton -import triton.language as tl from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton.utils.common_utils import serialize_dict from aiter.ops.triton.gemm_a16wfp4 import ( gemm_a16wfp4, ) -from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger()