Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
// TritonAMDGPUTransforms passes
mlir::registerTritonAMDGPUAccelerateMatmul();
mlir::registerTritonAMDGPUOptimizeEpilogue();
mlir::registerTritonAMDGPUHoistLayoutConversions();
mlir::registerTritonAMDGPUReorderInstructions();
mlir::registerTritonAMDGPUBlockPingpong();
mlir::registerTritonAMDGPUStreamPipeline();
Expand Down
86 changes: 86 additions & 0 deletions test/TritonGPU/amd/amd-hoist-cvtToDotOp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-hoist-layout-conversions | FileCheck %s

// Hoist convert_layout out of the loop since the defining op of the src is out of the loop

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
// CHECK-LABEL: hoist_cvtToDotOp
// CHECK: %[[AF16:.*]] = arith.truncf
// CHECK-NEXT: %[[opA:.*]] = ttg.convert_layout %[[AF16]]
// CHECK-NEXT: scf.for
// CHECK: tt.dot %[[opA]]
module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func public @hoist_cvtToDotOp(%opA: tensor<256x128xf32, #blocked>, %opB: tensor<128x256xf16, #dotOp1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%0 = arith.truncf %opA : tensor<256x128xf32, #blocked> to tensor<256x128xf16, #blocked>
%1:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
%2 = ttg.convert_layout %0 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #dotOp0>
%3 = tt.dot %2, %opB, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
scf.yield %3 : tensor<256x256xf32, #mma>
}
tt.store %C_ptr, %1#0: tensor<256x256x!tt.ptr<f32>, #mma>
tt.return
}
}


// -----

// Keep convert_layout inside the loop since the defining op of the src is inside the loop

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
// CHECK-LABEL: defOp_in_loop
// CHECK: scf.for
// CHECK: %[[AF16:.*]] = arith.truncf
// CHECK-NEXT: %[[opA:.*]] = ttg.convert_layout %[[AF16]]
// CHECK: tt.dot %[[opA]]
module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func public @defOp_in_loop(%opA: tensor<256x128xf32, #blocked>, %opB: tensor<128x256xf16, #dotOp1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%1:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
%0 = arith.truncf %opA : tensor<256x128xf32, #blocked> to tensor<256x128xf16, #blocked>
%2 = ttg.convert_layout %0 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #dotOp0>
%3 = tt.dot %2, %opB, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
scf.yield %3 : tensor<256x256xf32, #mma>
}
tt.store %C_ptr, %1#0: tensor<256x256x!tt.ptr<f32>, #mma>
tt.return
}
}


// -----

// Keep convert_layout inside the loop since the defining op is a block argument of the loop

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
// CHECK-LABEL: defOp_blockArg
// CHECK: scf.for
// CHECK-NEXT: %[[opA:.*]] = ttg.convert_layout
// CHECK: tt.dot %[[opA]]
module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func public @defOp_blockArg(%opA: tensor<256x128xf16, #blocked>, %opB: tensor<128x256xf16, #dotOp1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%1:2 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst, %arg2 = %opA) -> (tensor<256x256xf32, #mma>, tensor<256x128xf16, #blocked>) : i32 {
%2 = ttg.convert_layout %arg2 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #dotOp0>
%3 = tt.dot %2, %opB, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
scf.yield %3, %arg2 : tensor<256x256xf32, #mma>, tensor<256x128xf16, #blocked>
}
tt.store %C_ptr, %1#0: tensor<256x256x!tt.ptr<f32>, #mma>
tt.return
}
}
93 changes: 0 additions & 93 deletions test/TritonGPU/amd/amd-reorder-instructions.mlir
Original file line number Diff line number Diff line change
@@ -1,98 +1,5 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s

// Check that we place local_alloc, local_store (optional) and local_load right after definition of their operands
// in cases where local_alloc is in the loop but it's operand is not.
// This is useful for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers
// throughout the computation.

// CHECK-LABEL: hoist_q_out_of_the_loop
// CHECK: %[[TRUNCF:.+]] = arith.truncf
// CHECK-NEXT: %[[ALLOC:.+]] = ttg.local_alloc %[[TRUNCF]]
// CHECK-NEXT: ttg.local_load %[[ALLOC]]
// CHECK: scf.for
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
%c0_i64 = arith.constant 0 : i64
%cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
%1 = tt.get_program_id y : i32
%2 = arith.muli %1, %arg7 : i32
%3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
%12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
%41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
%42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
%43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
%44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
%45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
%54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
%73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
%74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
%75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #smem>
%76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #smem> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
%77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
%78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
%79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
%107 = arith.addi %arg26, %c128_i64 : i64
scf.yield %107 : i64
} {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
tt.return
}
}


// -----
// Check that reordering described in hoist_q_out_of_the_loop is not done in the case where both
// local_alloc and it's src tensor defining op are in the loop.
// CHECK-LABEL: no_hoist_q_type_reordering
// CHECK: scf.for
// CHECK: %[[TRUNCF:.+]] = arith.truncf
// CHECK-NEXT: arith.constant
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
%c0_i64 = arith.constant 0 : i64
%1 = tt.get_program_id y : i32
%2 = arith.muli %1, %arg7 : i32
%3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
%12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
%41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
%42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
%43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
%44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
%54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
%45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
%73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
%74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
%75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #smem>
%76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #smem> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
%77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
%78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
%79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
%107 = arith.addi %arg26, %c128_i64 : i64
scf.yield %107 : i64
} {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
tt.return
}
}

// -----
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
Expand Down
1 change: 1 addition & 0 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def make_ttgir(mod, metadata, options):
passes.ttgpuir.add_remove_layout_conversions(pm)
amd.passes.ttgpuir.add_optimize_epilogue(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
amd.passes.ttgpuir.add_hoist_layout_conversions(pm)

global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
Expand Down
2 changes: 2 additions & 0 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ std::unique_ptr<Pass> createTritonAMDGPUVerifier();

std::unique_ptr<Pass> createTritonAMDGPUOptimizeEpiloguePass();

std::unique_ptr<Pass> createTritonAMDGPUHoistLayoutConversionsPass();

std::unique_ptr<Pass> createTritonAMDGPUCanonicalizePointersPass();

std::unique_ptr<Pass> createTritonAMDGPUConvertToBufferOpsPass(
Expand Down
18 changes: 18 additions & 0 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ def TritonAMDGPUOptimizeEpilogue : Pass<"tritonamdgpu-optimize-epilogue", "mlir:

}

def TritonAMDGPUHoistLayoutConversions : Pass<"tritonamdgpu-hoist-layout-conversions", "mlir::triton::FuncOp"> {
let summary = "Hoist layout conversions out of the loop";

let description = [{
Comment thread
zhanglx13 marked this conversation as resolved.
This pass tries to hoist a convert_layout op out of the loop if 1) its dst is a tensor
of dotOperand layout, and 2) its src is defined out of the loop.
The rational is as follows:
1. When the defining op of the src is out of the loop, it means the src is loop-invariant.
Then we can potentially hoist this convert_layout op, since it's also loop-invariant.
2. The drawback of this LICM is higher register pressure. However, on AMD GPUs, we have
a larger register file but smaller shared memory. It's beneficial to keep loop-invariant
variables in registers rather than loading them from shared memory in the loop.
}];

let constructor = "mlir::createTritonAMDGPUHoistLayoutConversionsPass()";

}

def TritonAMDGPUCanonicalizePointers : Pass<"tritonamdgpu-canonicalize-pointers", "mlir::triton::FuncOp"> {
let summary = "Canonicalize pointers: rewrite pointers passed to load/store operation as a `<basePtr, offset>` pair.";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_triton_library(TritonAMDGPUTransforms
CanonicalizePointers.cpp
ConvertToBufferOps.cpp
OptimizeEpilogue.cpp
HoistLayoutConversions.cpp
ReorderInstructions.cpp
StreamPipeline.cpp
MfmaGroup.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include "TritonAMDGPUTransforms/Passes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;

// Hoist convert_layout out of the loop if the src is defined out of the loop.
// This is a heuristic driven by optimizing fused attention kernels, in which
// we want to load Q tensor and keep it in register, instead of loading it
// (neither from global or shared memory) at every iteration of the loop.
static void hoistCvtDotOpOutOfLoop(ttg::ConvertLayoutOp cvtOp) {
Comment thread
zhanglx13 marked this conversation as resolved.
// Check the dst of cvt has dotOperand layout
RankedTensorType rtType = dyn_cast<RankedTensorType>(cvtOp.getType());
if (!rtType)
return;
Attribute encoding = rtType.getEncoding();
if (!encoding)
return;
if (!isa<ttg::DotOperandEncodingAttr>(encoding))
return;
// Check the src of cvt is defined out of the loop
auto srcDefOp = cvtOp.getSrc().getDefiningOp();
if (srcDefOp) {
scf::ForOp parentForOp = cvtOp->getParentOfType<scf::ForOp>();
if (parentForOp && !parentForOp->isAncestor(srcDefOp)) {
cvtOp->moveAfter(srcDefOp);
}
}
}

#define GEN_PASS_CLASSES
#include "TritonAMDGPUTransforms/Passes.h.inc"

namespace {
struct TritonAMDGPUHoistLayoutConversionsPass
: public TritonAMDGPUHoistLayoutConversionsBase<
TritonAMDGPUHoistLayoutConversionsPass> {

void runOnOperation() override {
tt::FuncOp funcOp = getOperation();

SmallVector<ttg::ConvertLayoutOp> cvtOps;
funcOp.walk([&](ttg::ConvertLayoutOp cvtOp) { cvtOps.push_back(cvtOp); });

for (auto cvtOp : cvtOps)
hoistCvtDotOpOutOfLoop(cvtOp);
}
};
} // namespace

std::unique_ptr<Pass> mlir::createTritonAMDGPUHoistLayoutConversionsPass() {
return std::make_unique<TritonAMDGPUHoistLayoutConversionsPass>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,77 +131,6 @@ static void sinkDotConversion(triton::FuncOp funcOp) {
kv.first->moveBefore(kv.second);
}

// Adjust the placement of shared memory writes and reads to immediately follow
// the definition of their operands in case where shared memory write is in the
// loop but its operand is not.
//
// This is a heuristic driven by optimizing fused attention by hoisting Q tensor
// shared memory read/write operations outside of the loop, as Q is a loop
// invariant and can be loaded once before entering the loop. But it should be
// generally applicable.
//
// There are two possible patterns for this adjustment depending on whether the
// write to shared memory is performed using an optional `local_alloc` argument
// or a `local_store` instruction.
//
// 1) %1 = some_op ... (typically a load or an operation that scales the tensor
// after loading)
// %2 = local_alloc %1
// %3 = local_load %2
//
// 2) %1 = some_op ...
// %2 = local_alloc
// %3 = local_store %1, %2
// %4 = local_load %2
static void hoistLocalLoad(triton::FuncOp funcOp) {
funcOp.walk([&](ttg::LocalLoadOp localLoad) {
auto localAlloc = localLoad.getSrc().getDefiningOp<ttg::LocalAllocOp>();
if (!localAlloc)
return;

// Case when localAlloc has operands
if (localAlloc->getNumOperands() == 1) {
if (!localAlloc->hasOneUse())
return;

auto srcTensorOp = localAlloc.getSrc().getDefiningOp();
// Check if localAlloc is in the loop but it's src tensor defining op is
// outside of it.
if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp))
return;

localAlloc->moveAfter(srcTensorOp);
localLoad->moveAfter(localAlloc);
return;
}

// Case when localAlloc has no operands
assert(localAlloc->getNumOperands() < 1);
auto allocVal = localAlloc->getResult(0);

// Check if the localAlloc has exactly two uses (localStore and localLoad)
int numUses = std::distance(allocVal.use_begin(), allocVal.use_end());
if (numUses != 2)
return;

// localStore comes before localLoad in block.
Operation *localStore = getFirstUseInSameBlock(localAlloc);
if (!isa<ttg::LocalStoreOp>(localStore))
return;

auto srcTensorOp = localStore->getOperand(0).getDefiningOp();
// Check if localStore is in the loop but it's src tensor defining op is
// outside of it.
if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) {
return;
}

localAlloc->moveAfter(srcTensorOp);
localStore->moveAfter(localAlloc);
localLoad->moveAfter(localStore);
});
}

// Sink conversion after the last dealloc but before the first use in its block.
// This helps to avoid unnecessary shared memory allocation.
static void moveDownCoversion(triton::FuncOp funcOp) {
Expand Down Expand Up @@ -409,8 +338,6 @@ struct TritonAMDGPUReorderInstructionsPass
void runOnOperation() override {
ModuleOp m = getOperation();
for (auto funcOp : m.getOps<triton::FuncOp>()) {
hoistLocalLoad(funcOp);

sinkDotConversion(funcOp);
moveDownCoversion(funcOp);

Expand Down
Loading