Skip to content
2 changes: 1 addition & 1 deletion aiter/ops/triton/gluon/gemm_a8w8_blockscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _gemm_a8w8_blockscale_kernel(
)
mfma_layout: gl.constexpr = gl.amd.AMDMFMALayout(
version=4,
instr_shape=[16, 16],
instr_shape=[16, 16, 32], # V_MFMA_F32_16X16X32_FP8_FP8 instruction
transposed=True,
warps_per_cta=[NUM_WARPS // 2, 2],
)
Expand Down
5 changes: 0 additions & 5 deletions op_tests/triton_tests/test_gemm_a8w8_blockscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,6 @@ def test_gemm(dtype, M, N, K, layout, output, impl: str):
torch.cuda.synchronize()

block_shape_n, block_shape_k = block_shape
if K % block_shape_k != 0:
pytest.skip(
"Latest upstream compiler as of Aug 22 (necessary for Gluon) causes"
" infinite hang when EVEN_K is false. Try seeing if it's fixed if it's been a while."
)

if impl == "gluon" and int(DEVICE_ARCH.split("MI")[1].replace("X", "")) < 350:
pytest.skip(
Expand Down