From efd0fffcff72d8e2129901499df40c9568a34539 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Tue, 25 Feb 2025 16:22:27 +0000 Subject: [PATCH 1/3] [AMD] [FA] Hoist convert_layout to dotOp for Q out of the loop This PR adds a new amd.pass that hoists conver_layout to dotOperand layout for the Q tensor out of the loop. Therefore, Q tensor is kept in registers instead of being loaded at every iteration of the loop. This PR is actually achieving the same thing as https://github.com/triton-lang/triton/pull/4901. However, https://github.com/triton-lang/triton/pull/4901 does not hoist local_load for Q in the epilogue, making Q tensor live in shared memory all the time. On the other hand, this PR does the trick before stream-pipeline pass. Therefore, the livessness of Q tensor in shared memory is limited in the prologue. --- bin/RegisterTritonDialects.h | 1 + .../amd/amd-reorder-instructions.mlir | 93 ------------------- third_party/amd/backend/compiler.py | 1 + .../include/TritonAMDGPUTransforms/Passes.h | 2 + .../include/TritonAMDGPUTransforms/Passes.td | 12 +++ .../lib/TritonAMDGPUTransforms/CMakeLists.txt | 1 + .../HoistLayoutConversions.cpp | 63 +++++++++++++ .../ReorderInstructions.cpp | 73 --------------- third_party/amd/python/triton_amd.cc | 2 + 9 files changed, 82 insertions(+), 166 deletions(-) create mode 100644 third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index ebd32ad4422f..3507b06c1e9f 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -60,6 +60,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { // TritonAMDGPUTransforms passes mlir::registerTritonAMDGPUAccelerateMatmul(); mlir::registerTritonAMDGPUOptimizeEpilogue(); + mlir::registerTritonAMDGPUHoistLayoutConversions(); mlir::registerTritonAMDGPUReorderInstructions(); mlir::registerTritonAMDGPUBlockPingpong(); mlir::registerTritonAMDGPUStreamPipeline(); diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index c5cb7dd24062..e481b674245a 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -1,98 +1,5 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s -// Check that we place local_alloc, local_store (optional) and local_load right after definition of their operands -// in cases where local_alloc is in the loop but it's operand is not. -// This is useful for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers -// throughout the computation. - -// CHECK-LABEL: hoist_q_out_of_the_loop -// CHECK: %[[TRUNCF:.+]] = arith.truncf -// CHECK-NEXT: %[[ALLOC:.+]] = ttg.local_alloc %[[TRUNCF]] -// CHECK-NEXT: ttg.local_load %[[ALLOC]] -// CHECK: scf.for -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> -#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { - tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant 1.44269502 : f32 - %c128_i32 = arith.constant 128 : i32 - %c128_i64 = arith.constant 128 : i64 - %c0_i64 = arith.constant 0 : i64 - %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> - %1 = tt.get_program_id y : i32 - %2 = arith.muli %1, %arg7 : i32 - %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 - %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> - %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> - %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> - %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> - %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> - %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> - %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { - %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> - %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> - %75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #smem> - %76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #smem> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> - %77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem> - %78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> - %107 = arith.addi %arg26, %c128_i64 : i64 - scf.yield %107 : i64 - } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} - tt.return - } -} - - -// ----- -// Check that reordering described in hoist_q_out_of_the_loop is not done in the case where both -// local_alloc and it's src tensor defining op are in the loop. -// CHECK-LABEL: no_hoist_q_type_reordering -// CHECK: scf.for -// CHECK: %[[TRUNCF:.+]] = arith.truncf -// CHECK-NEXT: arith.constant -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> -#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { - tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant 1.44269502 : f32 - %c128_i32 = arith.constant 128 : i32 - %c128_i64 = arith.constant 128 : i64 - %c0_i64 = arith.constant 0 : i64 - %1 = tt.get_program_id y : i32 - %2 = arith.muli %1, %arg7 : i32 - %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 - %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> - %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> - %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> - %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> - %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> - %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { - %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> - %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> - %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> - %75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #smem> - %76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #smem> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> - %77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem> - %78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> - %107 = arith.addi %arg26, %c128_i64 : i64 - scf.yield %107 : i64 - } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} - tt.return - } -} - -// ----- #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 75ff3ef29d2b..8a6d2b73f4fe 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -206,6 +206,7 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_optimize_epilogue(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) + amd.passes.ttgpuir.add_hoist_layout_conversions(pm) global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0")) local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0")) diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 70fab0b00723..f5311723d1b1 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -25,6 +25,8 @@ std::unique_ptr createTritonAMDGPUVerifier(); std::unique_ptr createTritonAMDGPUOptimizeEpiloguePass(); +std::unique_ptr createTritonAMDGPUHoistLayoutConversionsPass(); + std::unique_ptr createTritonAMDGPUCanonicalizePointersPass(); std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass( diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 6a58a31ad75b..d058504e4d8b 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -65,6 +65,18 @@ def TritonAMDGPUOptimizeEpilogue : Pass<"tritonamdgpu-optimize-epilogue", "mlir: } +def TritonAMDGPUHoistLayoutConversions : Pass<"tritonamdgpu-hoist-layout-conversions", "mlir::ModuleOp"> { + let summary = "Hoist layout conversions out of the loop"; + + let description = [{ + }]; + + let constructor = "mlir::createTritonAMDGPUHoistLayoutConversionsPass()"; + + let dependentDialects = []; + +} + def TritonAMDGPUCanonicalizePointers : Pass<"tritonamdgpu-canonicalize-pointers", "mlir::triton::FuncOp"> { let summary = "Canonicalize pointers: rewrite pointers passed to load/store operation as a `` pair."; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index 5478d9b4d780..3719d0c383d7 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_triton_library(TritonAMDGPUTransforms CanonicalizePointers.cpp ConvertToBufferOps.cpp OptimizeEpilogue.cpp + HoistLayoutConversions.cpp ReorderInstructions.cpp StreamPipeline.cpp MfmaGroup.cpp diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp new file mode 100644 index 000000000000..6b81ed202f66 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp @@ -0,0 +1,63 @@ +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +namespace ttg = mlir::triton::gpu; + +// Hoist convert_layout out of the loop if the src is defined out of the loop. +// This is a heuristic driven by optimizing fused attention kernels, in which +// we want to load Q tensor and keep it in register, instead of loading it +// (neither from global or shared memory) at every iteration of the loop. +static void hoistCvtDotOpOutOfLoop(ttg::ConvertLayoutOp cvtOp) { + // Check the dst of cvt has dotOperand layout + RankedTensorType rtType = dyn_cast(cvtOp.getType()); + if (!rtType) + return; + Attribute encoding = rtType.getEncoding(); + if (!encoding) + return; + if (!isa(encoding)) + return; + // Check the src of cvt is defined out of the loop + auto srcDefOp = cvtOp.getSrc().getDefiningOp(); + if (srcDefOp) { + scf::ForOp parentForOp = cvtOp->getParentOfType(); + if (parentForOp && !parentForOp->isAncestor(srcDefOp)) { + cvtOp->moveAfter(srcDefOp); + } + } +} + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +namespace { +struct TritonAMDGPUHoistLayoutConversionsPass + : public TritonAMDGPUHoistLayoutConversionsBase< + TritonAMDGPUHoistLayoutConversionsPass> { + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + for (auto funcOp : m.getOps()) { + funcOp.walk([&](ttg::ConvertLayoutOp cvtOp) -> void { + hoistCvtDotOpOutOfLoop(cvtOp); + }); + } + } +}; +} // namespace + +std::unique_ptr mlir::createTritonAMDGPUHoistLayoutConversionsPass() { + return std::make_unique(); +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index f1c84a9cf5bb..23f1a5c21ed4 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -131,77 +131,6 @@ static void sinkDotConversion(triton::FuncOp funcOp) { kv.first->moveBefore(kv.second); } -// Adjust the placement of shared memory writes and reads to immediately follow -// the definition of their operands in case where shared memory write is in the -// loop but its operand is not. -// -// This is a heuristic driven by optimizing fused attention by hoisting Q tensor -// shared memory read/write operations outside of the loop, as Q is a loop -// invariant and can be loaded once before entering the loop. But it should be -// generally applicable. -// -// There are two possible patterns for this adjustment depending on whether the -// write to shared memory is performed using an optional `local_alloc` argument -// or a `local_store` instruction. -// -// 1) %1 = some_op ... (typically a load or an operation that scales the tensor -// after loading) -// %2 = local_alloc %1 -// %3 = local_load %2 -// -// 2) %1 = some_op ... -// %2 = local_alloc -// %3 = local_store %1, %2 -// %4 = local_load %2 -static void hoistLocalLoad(triton::FuncOp funcOp) { - funcOp.walk([&](ttg::LocalLoadOp localLoad) { - auto localAlloc = localLoad.getSrc().getDefiningOp(); - if (!localAlloc) - return; - - // Case when localAlloc has operands - if (localAlloc->getNumOperands() == 1) { - if (!localAlloc->hasOneUse()) - return; - - auto srcTensorOp = localAlloc.getSrc().getDefiningOp(); - // Check if localAlloc is in the loop but it's src tensor defining op is - // outside of it. - if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) - return; - - localAlloc->moveAfter(srcTensorOp); - localLoad->moveAfter(localAlloc); - return; - } - - // Case when localAlloc has no operands - assert(localAlloc->getNumOperands() < 1); - auto allocVal = localAlloc->getResult(0); - - // Check if the localAlloc has exactly two uses (localStore and localLoad) - int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); - if (numUses != 2) - return; - - // localStore comes before localLoad in block. - Operation *localStore = getFirstUseInSameBlock(localAlloc); - if (!isa(localStore)) - return; - - auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); - // Check if localStore is in the loop but it's src tensor defining op is - // outside of it. - if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { - return; - } - - localAlloc->moveAfter(srcTensorOp); - localStore->moveAfter(localAlloc); - localLoad->moveAfter(localStore); - }); -} - // Sink conversion after the last dealloc but before the first use in its block. // This helps to avoid unnecessary shared memory allocation. static void moveDownCoversion(triton::FuncOp funcOp) { @@ -409,8 +338,6 @@ struct TritonAMDGPUReorderInstructionsPass void runOnOperation() override { ModuleOp m = getOperation(); for (auto funcOp : m.getOps()) { - hoistLocalLoad(funcOp); - sinkDotConversion(funcOp); moveDownCoversion(funcOp); diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 376a700abd1c..aac2f16f4dee 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -66,6 +66,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { const std::string, int, int); ADD_PASS_WRAPPER_0("add_optimize_epilogue", mlir::createTritonAMDGPUOptimizeEpiloguePass); + ADD_PASS_WRAPPER_0("add_hoist_layout_conversions", + mlir::createTritonAMDGPUHoistLayoutConversionsPass); m.def("add_canonicalize_pointers", [](mlir::PassManager &pm) { pm.addNestedPass( mlir::createTritonAMDGPUCanonicalizePointersPass()); From 9de2cbb8341deaeeea97c8164175518263e3a2cf Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 26 Feb 2025 14:59:35 +0000 Subject: [PATCH 2/3] Move the pass into a FuncOp scope --- .../include/TritonAMDGPUTransforms/Passes.td | 12 +++++++++--- .../HoistLayoutConversions.cpp | 17 ++++++++++------- third_party/amd/python/triton_amd.cc | 6 ++++-- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index d058504e4d8b..37b257c8b4b8 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -65,16 +65,22 @@ def TritonAMDGPUOptimizeEpilogue : Pass<"tritonamdgpu-optimize-epilogue", "mlir: } -def TritonAMDGPUHoistLayoutConversions : Pass<"tritonamdgpu-hoist-layout-conversions", "mlir::ModuleOp"> { +def TritonAMDGPUHoistLayoutConversions : Pass<"tritonamdgpu-hoist-layout-conversions", "mlir::triton::FuncOp"> { let summary = "Hoist layout conversions out of the loop"; let description = [{ + This pass tries to hoist a convert_layout op out of the loop if 1) its dst is a tensor + of dotOperand layout, and 2) its src is defined out of the loop. + The rational is as follows: + 1. When the defining op of the src is out of the loop, it means the src is loop-invariant. + Then we can potentially hoist this convert_layout op, since it's also loop-invariant. + 2. The drawback of this LICM is higher register pressure. However, on AMD GPUs, we have + a larger register file but smaller shared memory. It's beneficial to keep loop-invariant + variables in registers rather than loading them from shared memory in the loop. }]; let constructor = "mlir::createTritonAMDGPUHoistLayoutConversionsPass()"; - let dependentDialects = []; - } def TritonAMDGPUCanonicalizePointers : Pass<"tritonamdgpu-canonicalize-pointers", "mlir::triton::FuncOp"> { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp index 6b81ed202f66..f2c3a1a69e73 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp @@ -11,6 +11,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; +namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; // Hoist convert_layout out of the loop if the src is defined out of the loop. @@ -46,14 +47,16 @@ struct TritonAMDGPUHoistLayoutConversionsPass TritonAMDGPUHoistLayoutConversionsPass> { void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp m = getOperation(); + tt::FuncOp funcOp = getOperation(); - for (auto funcOp : m.getOps()) { - funcOp.walk([&](ttg::ConvertLayoutOp cvtOp) -> void { - hoistCvtDotOpOutOfLoop(cvtOp); - }); - } + SmallVector cvtOps; + funcOp.walk([&](Operation *op) { + if (auto cvtOp = dyn_cast(op)) + cvtOps.push_back(cvtOp); + }); + + for (auto cvtOp : cvtOps) + hoistCvtDotOpOutOfLoop(cvtOp); } }; } // namespace diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index aac2f16f4dee..551da3f33c74 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -66,8 +66,10 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { const std::string, int, int); ADD_PASS_WRAPPER_0("add_optimize_epilogue", mlir::createTritonAMDGPUOptimizeEpiloguePass); - ADD_PASS_WRAPPER_0("add_hoist_layout_conversions", - mlir::createTritonAMDGPUHoistLayoutConversionsPass); + m.def("add_hoist_layout_conversions", [](mlir::PassManager &pm) { + pm.addNestedPass( + mlir::createTritonAMDGPUHoistLayoutConversionsPass()); + }); m.def("add_canonicalize_pointers", [](mlir::PassManager &pm) { pm.addNestedPass( mlir::createTritonAMDGPUCanonicalizePointersPass()); From 2e3af7d6285adf1a94e2f9e975c511e5356052a1 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 26 Feb 2025 15:39:12 +0000 Subject: [PATCH 3/3] Addressed review comments and added lit tests --- test/TritonGPU/amd/amd-hoist-cvtToDotOp.mlir | 86 +++++++++++++++++++ .../HoistLayoutConversions.cpp | 5 +- 2 files changed, 87 insertions(+), 4 deletions(-) create mode 100644 test/TritonGPU/amd/amd-hoist-cvtToDotOp.mlir diff --git a/test/TritonGPU/amd/amd-hoist-cvtToDotOp.mlir b/test/TritonGPU/amd/amd-hoist-cvtToDotOp.mlir new file mode 100644 index 000000000000..f8eba39a9f47 --- /dev/null +++ b/test/TritonGPU/amd/amd-hoist-cvtToDotOp.mlir @@ -0,0 +1,86 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-hoist-layout-conversions | FileCheck %s + +// Hoist convert_layout out of the loop since the defining op of the src is out of the loop + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// CHECK-LABEL: hoist_cvtToDotOp +// CHECK: %[[AF16:.*]] = arith.truncf +// CHECK-NEXT: %[[opA:.*]] = ttg.convert_layout %[[AF16]] +// CHECK-NEXT: scf.for +// CHECK: tt.dot %[[opA]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @hoist_cvtToDotOp(%opA: tensor<256x128xf32, #blocked>, %opB: tensor<128x256xf16, #dotOp1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = arith.truncf %opA : tensor<256x128xf32, #blocked> to tensor<256x128xf16, #blocked> + %1:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %2 = ttg.convert_layout %0 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #dotOp0> + %3 = tt.dot %2, %opB, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %1#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Keep convert_layout inside the loop since the defining op of the src is inside the loop + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// CHECK-LABEL: defOp_in_loop +// CHECK: scf.for +// CHECK: %[[AF16:.*]] = arith.truncf +// CHECK-NEXT: %[[opA:.*]] = ttg.convert_layout %[[AF16]] +// CHECK: tt.dot %[[opA]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @defOp_in_loop(%opA: tensor<256x128xf32, #blocked>, %opB: tensor<128x256xf16, #dotOp1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %1:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %0 = arith.truncf %opA : tensor<256x128xf32, #blocked> to tensor<256x128xf16, #blocked> + %2 = ttg.convert_layout %0 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #dotOp0> + %3 = tt.dot %2, %opB, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %1#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Keep convert_layout inside the loop since the defining op is a block argument of the loop + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// CHECK-LABEL: defOp_blockArg +// CHECK: scf.for +// CHECK-NEXT: %[[opA:.*]] = ttg.convert_layout +// CHECK: tt.dot %[[opA]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @defOp_blockArg(%opA: tensor<256x128xf16, #blocked>, %opB: tensor<128x256xf16, #dotOp1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %1:2 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst, %arg2 = %opA) -> (tensor<256x256xf32, #mma>, tensor<256x128xf16, #blocked>) : i32 { + %2 = ttg.convert_layout %arg2 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #dotOp0> + %3 = tt.dot %2, %opB, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + scf.yield %3, %arg2 : tensor<256x256xf32, #mma>, tensor<256x128xf16, #blocked> + } + tt.store %C_ptr, %1#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp index f2c3a1a69e73..416ee581d1b5 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp @@ -50,10 +50,7 @@ struct TritonAMDGPUHoistLayoutConversionsPass tt::FuncOp funcOp = getOperation(); SmallVector cvtOps; - funcOp.walk([&](Operation *op) { - if (auto cvtOp = dyn_cast(op)) - cvtOps.push_back(cvtOp); - }); + funcOp.walk([&](ttg::ConvertLayoutOp cvtOp) { cvtOps.push_back(cvtOp); }); for (auto cvtOp : cvtOps) hoistCvtDotOpOutOfLoop(cvtOp);