diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index e48cfca441d3..62499d8208cf 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -374,24 +374,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere) // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be // completed before we can remove the layoutIsOK check: - // 1. Support for AMD's MFMA and WMMA + // 1. Support for AMD's WMMA std::function layoutIsOK = [&](Attribute layout) { - if (auto nvidiaMma = dyn_cast(layout)) { - if (useLegacyMMAConversion) { - return false; - } - return true; + if (isa(layout)) { + return !useLegacyMMAConversion; } if (auto dotOperand = dyn_cast(layout)) { - if (auto nvidiaMma = - dyn_cast(dotOperand.getParent())) { - if (useLegacyMMAConversion) { - return false; - } + auto parent = dotOperand.getParent(); + if (isa(parent) && useLegacyMMAConversion) { + return false; + } + if (auto nvidiaMma = dyn_cast(parent)) { if (nvidiaMma.isAmpere()) { return true; } } + if (isa(parent)) { + return true; + } return false; } if (isa(layout)) { diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 83c9e535d8c0..a2c8f48718d9 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -7,6 +7,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { // CHECK-NOT: store // CHECK-NOT: load + // CHECK: llvm.return %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } @@ -21,6 +22,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { // CHECK: store // CHECK: load + // CHECK: llvm.return %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return }