diff --git a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py index b67dbb914b..918474fdcd 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py @@ -141,7 +141,9 @@ def _gemm_afp4_wfp4_kernel( cache_modifier=cache_modifier, ) - accumulator = tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1", accumulator) + accumulator = tl.dot_scaled( + a, a_scales, "e2m1", b, b_scales, "e2m1", accumulator + ) # Advance the ptrs to the next K block. a_ptrs += (BLOCK_SIZE_K // 2) * stride_ak @@ -340,7 +342,9 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales( b_ptrs, mask=offs_k[:, None] < K - k * (BLOCK_SIZE_K // 2), other=0 ) - accumulator = tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1", accumulator) + accumulator = tl.dot_scaled( + a, a_scales, "e2m1", b, b_scales, "e2m1", accumulator + ) # Advance the ptrs to the next K block. a_ptrs += (BLOCK_SIZE_K // 2) * stride_ak