110 changes: 91 additions & 19 deletions test/TritonGPU/amd/amd-reorder-instructions.mlir
@@ -1,28 +1,100 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s

// Check that we place local_alloc, local_store (optional) and local_load right after the definition of their operands
// in cases where local_alloc is in the loop but its operand is not.
// This is useful for making sure that the Q tensor in FA is hoisted out of the main loop and kept in registers
// throughout the computation.
// CHECK-LABEL: order_load_alloc_local_load
// CHECK: %[[LOAD:.+]] = tt.load
// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]]
// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
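//
// A sketch of the intended effect (hypothetical ordering, not taken from this test): if
// unrelated ops separate the tt.load from the triton_gpu.local_alloc that consumes it,
// the pass moves the local_alloc (and its local_load) up next to the load so the three
// ops end up adjacent, which is what the CHECK lines above verify.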
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @order_load_alloc_local_load(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked>) attributes {noinline = false} {
%9 = tt.load %arg0 : tensor<32x32x!tt.ptr<f32>, #blocked>
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
%10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared>
%cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
%13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
tt.store %arg0, %13 : tensor<32x32x!tt.ptr<f32>, #blocked>
tt.return
}
}

// -----

// CHECK-LABEL: hoist_q_out_of_the_loop
// CHECK: %[[TRUNCF:.+]] = arith.truncf
// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]]
// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
// CHECK: scf.for
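//
// In other words, after the pass the Q-scaling chain (ending in arith.truncf) is expected
// to be followed immediately by the local_alloc/local_load pair, all above the scf.for,
// so the loop body reads Q from registers instead of round-tripping it through LDS on
// every iteration.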
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
%c0_i64 = arith.constant 0 : i64
%cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
%1 = tt.get_program_id y : i32
%2 = arith.muli %1, %arg7 : i32
%3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
%12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
%41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
%42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
%43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
%44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
%45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
%54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
%73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
%74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
%75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory>
%76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
%77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory>
%78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
%79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
%107 = arith.addi %arg26, %c128_i64 : i64
scf.yield %107 : i64
} {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
tt.return
}
}


// -----
// Check that the reordering described in hoist_q_out_of_the_loop is not done when both
// local_alloc and its src tensor defining op are in the loop.
// CHECK-LABEL: no_hoist_q_type_reordering
// CHECK: scf.for
// CHECK: %[[TRUNCF:.+]] = arith.truncf
// CHECK-NEXT: arith.constant
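//
// Here arith.truncf, the op defining local_alloc's source tensor, is itself inside the
// loop, so no cross-loop-boundary hoisting applies and the constant that originally
// follows the truncf must stay in place.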
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
%c0_i64 = arith.constant 0 : i64
%1 = tt.get_program_id y : i32
%2 = arith.muli %1, %arg7 : i32
%3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
%12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
%41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
%42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
%43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
%44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
%54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
%45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
%73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
%74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
%75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory>
%76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
%77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory>
%78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
%79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
%107 = arith.addi %arg26, %c128_i64 : i64
scf.yield %107 : i64
} {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
tt.return
}
}

// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>

// CHECK-LABEL: order_load_alloc_local_load_local_store
// CHECK: %[[LOAD:.+]] = tt.load
// CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc
@@ -61,6 +61,14 @@ findEarlyInsertionPoint(Block *block, Operation *move) {
return ipnt;
}

// Return true if opInsideLoop is nested inside an scf::ForOp whose body does
// not also contain opOutsideLoop.
bool isCrossLoopBoundary(mlir::Operation *opInsideLoop,
mlir::Operation *opOutsideLoop) {
scf::ForOp parentForOp = opInsideLoop->getParentOfType<scf::ForOp>();
return parentForOp && !parentForOp->isAncestor(opOutsideLoop);
}
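
// Usage sketch (hypothetical IR, names for illustration only):
//
//   %q = arith.truncf ...              // defined outside the loop
//   scf.for ... {
//     %a = triton_gpu.local_alloc %q   // inside the loop
//   }
//
// isCrossLoopBoundary(allocOp, truncfOp) returns true here, which is the
// hoisting case handled below; it returns false when allocOp is not inside
// any scf::ForOp, or when both ops are inside the same loop.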

class TritonAMDGPUReorderInstructionsPass
: public TritonAMDGPUReorderInstructionsBase<
TritonAMDGPUReorderInstructionsPass> {
@@ -101,19 +109,28 @@ class TritonAMDGPUReorderInstructionsPass
kv.first->moveBefore(kv.second);
opToMove.clear();

// Adjust the placement of LDS writes and reads so that they immediately
// follow the definition of their operands, for the case where the LDS write
// is inside a loop but its operand is not. This is a heuristic for
// optimizing fused attention: the Q tensor's LDS read/write operations are
// hoisted out of the loop, since Q is loop-invariant and can be loaded once
// before entering the loop.
// There are two possible patterns for this adjustment, depending on whether
// the write to LDS is performed using an optional `local_alloc` argument or
// a `local_store` instruction.
//
// clang-format off
//
// 1) %1 = some_op ... (typically a load, or an op that scales the tensor after loading)
// %2 = local_alloc %1
// %3 = local_load %2
//
// 2) %1 = some_op ...
// %2 = local_alloc
// %3 = local_store %1, %2
// %4 = local_load %2
//
// clang-format on
m.walk([&](ttg::LocalLoadOp localLoad) {
auto localAlloc = localLoad.getSrc().getDefiningOp<ttg::LocalAllocOp>();
if (!localAlloc)
@@ -123,10 +140,15 @@ class TritonAMDGPUReorderInstructionsPass
if (localAlloc->getNumOperands() == 1) {
if (!localAlloc->hasOneUse())
return;

auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp();
// Check if localAlloc is in the loop but its src tensor defining op is
// outside of it.
if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) {
return;
}

localAlloc->moveAfter(srcTensorOp);
localLoad->moveAfter(localAlloc);
return;
}
@@ -145,10 +167,14 @@ class TritonAMDGPUReorderInstructionsPass
if (!isa<ttg::LocalStoreOp>(localStore))
return;

auto srcTensorOp = localStore->getOperand(0).getDefiningOp();
// Check if localStore is in the loop but its src tensor defining op is
// outside of it.
if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) {
return;
}

localAlloc->moveAfter(srcTensorOp);
localStore->moveAfter(localAlloc);
localLoad->moveAfter(localStore);
});