165 changes: 165 additions & 0 deletions test/TritonGPU/amd/amd-sched-2nd-load.mlir
@@ -0,0 +1,165 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s

// Check the logic of the sched-2nd-load optimization.
// The optimization should apply to the following tile sizes:
//   256x256x128
//   256x256x64
// The optimization should NOT apply to the following tile sizes:
//   256x64x128  (N = 64 < 128)
//   256x256x32  (K = 32 < 64)
// It should also NOT apply to an scf.for loop containing two dots.


#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
// Should apply: tile size 256x256x128 with single dot
// CHECK-LABEL: sink_2nd_load_256x256x128
// CHECK: %[[tileA:.*]] = tt.load
// CHECK-NEXT: local_load
// CHECK-NEXT: local_load
// CHECK-NEXT: %[[tileB:.*]] = tt.load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
// CHECK-NEXT: triton_gpu.local_store %[[tileB]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1>
%3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
%4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
%5 = tt.load %B_ptr : tensor<128x256x!tt.ptr<f16>, #blocked1>
triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
scf.yield %3 : tensor<256x256xf32, #mma>
}
tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
tt.return
}
}

// Should apply: tile size 256x256x64 with single dot
// CHECK-LABEL: sink_2nd_load_256x256x64
// CHECK: %[[tileA:.*]] = tt.load
// CHECK-NEXT: local_load
// CHECK-NEXT: local_load
// CHECK-NEXT: %[[tileB:.*]] = tt.load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
// CHECK-NEXT: triton_gpu.local_store %[[tileB]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<64x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0>
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1>
%3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
%4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
%5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
scf.yield %3 : tensor<256x256xf32, #mma>
}
tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
tt.return
}
}

// Should NOT apply: tile size 256x64x128 with single dot
// CHECK-LABEL: sink_2nd_load_256x64x128
// CHECK: %[[tileA:.*]] = tt.load
// CHECK-NEXT: %[[tileB:.*]] = tt.load
// CHECK-NEXT: local_load
// CHECK-NEXT: local_load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
// CHECK-NEXT: triton_gpu.local_store %[[tileB]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x64x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 {
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1>
%3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma>
%4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
%5 = tt.load %B_ptr : tensor<128x64x!tt.ptr<f16>, #blocked1>
triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>
scf.yield %3 : tensor<256x64xf32, #mma>
}
tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr<f32>, #mma>
tt.return
}
}

// Should NOT apply: tile size 256x256x32 with single dot
// CHECK-LABEL: sink_2nd_load_256x256x32
// CHECK: %[[tileA:.*]] = tt.load
// CHECK-NEXT: %[[tileB:.*]] = tt.load
// CHECK-NEXT: local_load
// CHECK-NEXT: local_load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
// CHECK-NEXT: triton_gpu.local_store %[[tileB]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<32x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0>
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1>
%3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
%4 = tt.load %A_ptr : tensor<256x32x!tt.ptr<f16>, #blocked>
%5 = tt.load %B_ptr : tensor<32x256x!tt.ptr<f16>, #blocked1>
triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
scf.yield %3 : tensor<256x256xf32, #mma>
}
tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
tt.return
}
}

// Should NOT apply: tile size 128x128x128 with two dots
// CHECK-LABEL: sink_2nd_load_128x128x128_two_dot
// CHECK: %[[tileA:.*]] = tt.load
// CHECK-NEXT: %[[tileB:.*]] = tt.load
// CHECK-NEXT: local_load
// CHECK-NEXT: local_load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
// CHECK-NEXT: triton_gpu.local_store %[[tileB]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @sink_2nd_load_128x128x128_two_dot(%A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<128x128x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 {
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0>
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1>
%3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
%6 = tt.dot %1, %2, %3 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
%4 = tt.load %A_ptr : tensor<128x128x!tt.ptr<f16>, #blocked>
%5 = tt.load %B_ptr : tensor<128x128x!tt.ptr<f16>, #blocked1>
triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %5, %B_LDS : tensor<128x128xf16, #blocked1> -> !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
scf.yield %6 : tensor<128x128xf32, #mma>
}
tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr<f32>, #mma>
tt.return
}
}
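For reference, the `RUN` line at the top of this test expands to the invocation below. This is a minimal sketch that assumes `triton-opt` and `FileCheck` are built and on the PATH (exact paths depend on the local build setup); lit normally performs the `%s` substitution itself.

```shell
# Run the AMD reorder-instructions pass on the new test and verify the
# CHECK / CHECK-NEXT ordering with FileCheck, mirroring the RUN line above
# with %s expanded to the test file path.
triton-opt test/TritonGPU/amd/amd-sched-2nd-load.mlir \
    -split-input-file -tritonamdgpu-reorder-instructions \
  | FileCheck test/TritonGPU/amd/amd-sched-2nd-load.mlir
```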
@@ -221,6 +221,78 @@ class TritonAMDGPUReorderInstructionsPass
dfgop->moveBefore(block, block->begin());
}
}

/**
* Sched-2nd-load optimization for matmul kernels with large tile sizes.
* The basic idea is to sink the second tt.load after the local_loads so that
* the global_load instructions can be interleaved with the mfma's. This helps
* hide the issue latency of the global_load instructions and improves
* performance on MI300X.
*
* It's assumed that the IR before this optimization has the following
* structure:
* ```mlir
* scf.for ..
* {
* tileA = tt.load a_ptr
* tileB = tt.load b_ptr
* opA = local_load bufferA
* opB = local_load bufferB
* res = tt.dot opA, opB
* local_store tileA, bufferA
* local_store tileB, bufferB
* }
* ```
* After this optimization, the IR is transformed to
* ```mlir
* scf.for ..
* {
* tileA = tt.load a_ptr
* opA = local_load bufferA
* opB = local_load bufferB
* tileB = tt.load b_ptr <-- the 2nd tt.load is sunk here
* res = tt.dot opA, opB
* local_store tileA, bufferA
* local_store tileB, bufferB
* }
* ```
* For now, we don't have a perfect heuristic for when this optimization
* should be applied. Therefore, we use a simple heuristic: the optimization
* is applied only when the tile sizes of A and B are large enough, i.e.
* nonKDim >= 128 and kDim >= 64, and only for typical matmul kernels, i.e.
* exactly two tt.load's and one dotOp inside the loop. We are experimenting
* with how to better control instruction scheduling and enable such
* optimizations.
*/
m.walk([&](scf::ForOp forOp) -> void {
SetVector<Operation *> loadOps;
triton::DotOp dotOp;
int nDotOps = 0;
for (Operation &op : forOp) {
if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
loadOps.insert(loadOp);
if (auto curOp = dyn_cast<triton::DotOp>(&op)) {
nDotOps++;
dotOp = curOp;
}
}
// Only apply the optimization when there are exactly 2 loads and 1 dot in
// the loop.
if (loadOps.size() != 2 || nDotOps != 1)
return;
// Only apply the optimization when tile size is large enough
// 1. nonKDim >= 128
// 2. kDim >= 64
auto ldAOp = dyn_cast<triton::LoadOp>(loadOps[0]);
auto tileAShape = cast<RankedTensorType>(ldAOp.getType()).getShape();
auto ldBOp = dyn_cast<triton::LoadOp>(loadOps[1]);
auto tileBShape = cast<RankedTensorType>(ldBOp.getType()).getShape();
if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 &&
tileBShape[1] >= 128))
return;
// move ldBOp right before tt.dot
loadOps[1]->moveBefore(dotOp);
});
}
};
