diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index 3e9b8a084058..686e5a24e8dd 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -1,28 +1,100 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s -// Check that we order load, local_alloc, local_store (optional) and local_load one after another. This is useful -// for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers +// Check that we place local_alloc, local_store (optional) and local_load right after definition of their operands +// in cases where local_alloc is in the loop but it's operand is not. +// This is useful for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers // throughout the computation. -// CHECK-LABEL: order_load_alloc_local_load -// CHECK: %[[LOAD:.+]] = tt.load -// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]] -// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @order_load_alloc_local_load(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { - %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> - tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> + +// CHECK-LABEL: hoist_q_out_of_the_loop +// CHECK: %[[TRUNCF:.+]] = arith.truncf +// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]] +// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] +// CHECK: scf.for +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> + %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> + %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> + %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> + %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { + %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> + %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> + %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> + %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> + %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %107 = arith.addi %arg26, %c128_i64 : i64 + scf.yield %107 : i64 + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} tt.return } } + + +// ----- +// Check that reordering described in hoist_q_out_of_the_loop is not done in the case where both +// local_alloc and it's src tensor defining op are in the loop. +// CHECK-LABEL: no_hoist_q_type_reordering +// CHECK: scf.for +// CHECK: %[[TRUNCF:.+]] = arith.truncf +// CHECK-NEXT: arith.constant +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i64 = arith.constant 0 : i64 + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> + %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> + %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> + %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> + %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { + %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> + %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> + %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> + %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> + %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> + %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %107 = arith.addi %arg26, %c128_i64 : i64 + scf.yield %107 : i64 + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} + tt.return + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> + // CHECK-LABEL: order_load_alloc_local_load_local_store // CHECK: %[[LOAD:.+]] = tt.load // CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir deleted file mode 100644 index bea937da60ee..000000000000 --- a/test/TritonGPU/amd/amd-sched-2nd-load.mlir +++ /dev/null @@ -1,165 +0,0 @@ -// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s - -// Check the logic of sched-2nd-load optimizations -// The following tile sizes should apply the optimization -// 256x256x128 -// 256x256x64 -// The following tile sizes should NOT apply the optimization -// 256x64x128 -// 256x256x32 -// scf.for loop with two dots should not apply the optimization - - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> -// Should apply: tile size 256x256x128 with single dot -// CHECK-LABEL: sink_2nd_load_256x256x128 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} - -// Should apply: tile size 256x256x64 with single dot -// CHECK-LABEL: sink_2nd_load_256x256x64 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} - -// Should NOT apply: tile size 256x64x128 with single dot -// CHECK-LABEL: sink_2nd_load_256x64x128 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x64xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr, #mma> - tt.return - } -} - -// Should NOT apply: tile size 256x256x32 with single dot -// CHECK-LABEL: sink_2nd_load_256x256x32 -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %3 : tensor<256x256xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> - tt.return - } -} - -// Should NOT apply: tile size 128x128x128 with two dots -// CHECK-LABEL: sink_2nd_load_128x128x128_two_dot -// CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load -// CHECK-NEXT: local_load -// CHECK-NEXT: local_load -// CHECK-NEXT: tt.dot -// CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_128x128x128_two_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> - %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> - %6 = tt.dot %1, %2, %3 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> - %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> - triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<128x128xf16, #blocked1> -> !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %6 : tensor<128x128xf32, #mma> - } - tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr, #mma> - tt.return - } -} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 5c3e0ea4c8c5..e122f15fd901 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -61,6 +61,14 @@ findEarlyInsertionPoint(Block *block, Operation *move) { return ipnt; } +// Check if the operation opInsideLoop is inside any scf::ForOp and +// opOutsideLoop is not inside the same loop. +bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { + scf::ForOp parentForOp = opInsideLoop->getParentOfType(); + return parentForOp && !parentForOp->isAncestor(opOutsideLoop); +} + class TritonAMDGPUReorderInstructionsPass : public TritonAMDGPUReorderInstructionsBase< TritonAMDGPUReorderInstructionsPass> { @@ -101,19 +109,28 @@ class TritonAMDGPUReorderInstructionsPass kv.first->moveBefore(kv.second); opToMove.clear(); - // Move writing to LDS and reading from LDS right after the loading of a - // tensor from global memory. There are 2 possible patterns depending on - // whether writing to LDS is done using an optional local_alloc argument or - // a local_store instruction: + // Adjust the placement of LDS writes and reads to immediately follow the + // definition of their operands in case where LDS write is in the + // loop but it's operand is not. This is a heuristic for optimizing fused + // attention by hoisting Q tensor LDS read/write operations outside of the + // loop, as Q is a loop invariant and can be loaded once before entering the + // loop. + // There are two possible patterns for this adjustment depending on + // whether the write to LDS is performed using an optional `local_alloc` + // argument or a `local_store` instruction. // - // 1) %1 = load %ptr + // clang-format off + // + // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading) // %2 = local_alloc %1 // %3 = local_load %2 // - // 2) %1 = load %ptr + // 2) %1 = some_op ... // %2 = local_alloc // %3 = local_store %1, %2 // %4 = local_load %2 + // + // clang-format on m.walk([&](ttg::LocalLoadOp localLoad) { auto localAlloc = localLoad.getSrc().getDefiningOp(); if (!localAlloc) @@ -123,10 +140,15 @@ class TritonAMDGPUReorderInstructionsPass if (localAlloc->getNumOperands() == 1) { if (!localAlloc->hasOneUse()) return; - auto loadOp = localAlloc->getOperand(0).getDefiningOp(); - if (!loadOp) + + auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp(); + // Check if localAlloc is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) { return; - localAlloc->moveAfter(loadOp); + } + + localAlloc->moveAfter(srcTensorOp); localLoad->moveAfter(localAlloc); return; } @@ -145,10 +167,14 @@ class TritonAMDGPUReorderInstructionsPass if (!isa(localStore)) return; - auto loadOp = localStore->getOperand(0).getDefiningOp(); - if (!loadOp) + auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); + // Check if localStore is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { return; - localAlloc->moveAfter(loadOp); + } + + localAlloc->moveAfter(srcTensorOp); localStore->moveAfter(localAlloc); localLoad->moveAfter(localStore); }); @@ -221,78 +247,6 @@ class TritonAMDGPUReorderInstructionsPass dfgop->moveBefore(block, block->begin()); } } - - /** - * Sched-load optimization for matmul kernels with large tile sizes - * The basic idea of sched-load optimization is to sink the 2nd tt.load - * after local_load so that global_load instructions can be interleaved with - * mfma's. This can help hide the issue latency of global_load instructions - * and improve performance on MI300X. - * - * It's assumed that the IR before this optimization has the following - * structure: - * ```mlir - * scf.for .. - * { - * tileA = tt.load a_ptr - * tileB = tt.load b_ptr - * opA = local_load bufferA - * opB = local_load bufferB - * res = tt.dot opA, opB - * local_store tileA, bufferA - * local_store tileB, bufferB - * } - * ``` - * After this optimization, the IR is transformed to - * ```mlir - * scf.for .. - * { - * tileA = tt.load a_ptr - * opA = local_load bufferA - * opB = local_load bufferB - * tileB = tt.load b_ptr <-- 2nd tt.load is sinked here - * res = tt.dot opA, opB - * local_store tileA, bufferA - * local_store tileB, bufferB - * } - * ``` - * For now, we don't have a perfect hueristic about when should this - * optimization be applied. Therefore, we implement a simple hueristic that - * this is applied when the tile size of A and B are large enough, i.e. - * nonKDim >= 128 and kDim >= 64. And also this is only applied for typical - * matmul kernels, i.e. only two tt.load's and one dotOp inside the loop. We - * are experimenting how to better control instruction scheduling and enable - * such optimizations. - */ - m.walk([&](scf::ForOp forOp) -> void { - SetVector loadOps; - triton::DotOp dotOp; - int nDotOps = 0; - for (Operation &op : forOp) { - if (auto loadOp = dyn_cast(&op)) - loadOps.insert(loadOp); - if (auto curOp = dyn_cast(&op)) { - nDotOps++; - dotOp = curOp; - } - } - // Only apply the optimization when there are 2 load's and 1 dot in the - // loop - if (loadOps.size() != 2 || nDotOps != 1) - return; - // Only apply the optimization when tile size is large enough - // 1. nonKDim >= 128 - // 2. kDim >= 64 - auto ldAOp = dyn_cast(loadOps[0]); - auto tileAShape = cast(ldAOp.getType()).getShape(); - auto ldBOp = dyn_cast(loadOps[1]); - auto tileBShape = cast(ldBOp.getType()).getShape(); - if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && - tileBShape[1] >= 128)) - return; - // move ldBOp right before tt.dot - loadOps[1]->moveBefore(dotOp); - }); } };