diff --git a/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4_pre_quant.py index d94b3ae958..86f7748acf 100755 --- a/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4_pre_quant.py @@ -4,14 +4,15 @@ import functools import json import os + import triton import triton.language as tl -from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd + from ..utils._triton import arch_info +from ..utils._triton.kernel_repr import make_kernel_repr +from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils.core import AITER_TRITON_CONFIGS_PATH from .quant import _mxfp4_quant_op -from ..utils._triton.kernel_repr import make_kernel_repr - _batched_gemm_afp4_wfp4_pre_quant_repr = make_kernel_repr( "_batched_gemm_afp4_wfp4_pre_quant_kernel", @@ -62,8 +63,8 @@ def _batched_gemm_afp4_wfp4_pre_quant_kernel( stride_am, stride_ak, stride_bb, - stride_bk, stride_bn, + stride_bk, stride_cb, stride_ck, stride_cm,