diff --git a/aiter/ops/triton/gluon/gemm_a8w8_blockscale.py b/aiter/ops/triton/gluon/gemm_a8w8_blockscale.py index 340398cd74..9e65516366 100644 --- a/aiter/ops/triton/gluon/gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/gluon/gemm_a8w8_blockscale.py @@ -120,7 +120,7 @@ def _gemm_a8w8_blockscale_kernel( ) mfma_layout: gl.constexpr = gl.amd.AMDMFMALayout( version=4, - instr_shape=[16, 16], + instr_shape=[16, 16, 32], # V_MFMA_F32_16X16X32_FP8_FP8 instruction transposed=True, warps_per_cta=[NUM_WARPS // 2, 2], ) diff --git a/op_tests/triton_tests/test_gemm_a8w8_blockscale.py b/op_tests/triton_tests/test_gemm_a8w8_blockscale.py index c2c981e8cb..f65cd4be10 100644 --- a/op_tests/triton_tests/test_gemm_a8w8_blockscale.py +++ b/op_tests/triton_tests/test_gemm_a8w8_blockscale.py @@ -183,11 +183,6 @@ def test_gemm(dtype, M, N, K, layout, output, impl: str): torch.cuda.synchronize() block_shape_n, block_shape_k = block_shape - if K % block_shape_k != 0: - pytest.skip( - "Latest upstream compiler as of Aug 22 (necessary for Gluon) causes" - " infinite hang when EVEN_K is false. Try seeing if it's fixed if it's been a while." - ) if impl == "gluon" and int(DEVICE_ARCH.split("MI")[1].replace("X", "")) < 350: pytest.skip(