diff --git a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py
index b67dbb914b..918474fdcd 100644
--- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py
+++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py
@@ -141,7 +141,9 @@ def _gemm_afp4_wfp4_kernel(
             cache_modifier=cache_modifier,
         )
 
-        accumulator = tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1", accumulator)
+        accumulator = tl.dot_scaled(
+            a, a_scales, "e2m1", b, b_scales, "e2m1", accumulator
+        )
 
         # Advance the ptrs to the next K block.
         a_ptrs += (BLOCK_SIZE_K // 2) * stride_ak
@@ -340,7 +342,9 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales(
             b_ptrs, mask=offs_k[:, None] < K - k * (BLOCK_SIZE_K // 2), other=0
         )
 
-        accumulator = tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1", accumulator)
+        accumulator = tl.dot_scaled(
+            a, a_scales, "e2m1", b, b_scales, "e2m1", accumulator
+        )
 
         # Advance the ptrs to the next K block.
         a_ptrs += (BLOCK_SIZE_K // 2) * stride_ak
diff --git a/op_tests/test_mha.py b/op_tests/test_mha.py
index 4e7adc777c..e2e3d80302 100644
--- a/op_tests/test_mha.py
+++ b/op_tests/test_mha.py
@@ -365,18 +365,26 @@ def test_flash_attn_output(
         dbias_tol = max(10 * (dbias_pt - dbias_ref).abs().max().item(), 0.01)
         assert (dbias - dbias_ref).abs().max().item() <= dbias_tol
 
-    fwd_flop = nheads * (seqlen_q * seqlen_k * d * 2 + seqlen_q * seqlen_k * d_v * 2)
+    fwd_flop = (
+        batch_size
+        * nheads
+        * (seqlen_q * seqlen_k * d * 2 + seqlen_q * seqlen_k * d_v * 2)
+    )
     dtype_bytes = torch.finfo(dtype).bits // 8
     fwd_num_bytes = (
-        nheads
+        batch_size
+        * nheads
         * dtype_bytes
         * (seqlen_q * d + seqlen_k * d + seqlen_k * d_v + seqlen_q * d_v)
     )
-    bwd_flop = nheads * (
-        seqlen_q * seqlen_k * d * 2 * 3 + seqlen_q * seqlen_k * d_v * 2 * 2
+    bwd_flop = (
+        batch_size
+        * nheads
+        * (seqlen_q * seqlen_k * d * 2 * 3 + seqlen_q * seqlen_k * d_v * 2 * 2)
     )
     bwd_num_bytes = (
-        2 * fwd_num_bytes + nheads * (torch.finfo(torch.float).bits // 8) * seqlen_q
+        2 * fwd_num_bytes
+        + batch_size * nheads * (torch.finfo(torch.float).bits // 8) * seqlen_q
     )
     ret = {}
     ret["fwd_us"] = us_fwd