diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index 3e9b8a084058..686e5a24e8dd 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -1,28 +1,100 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s -// Check that we order load, local_alloc, local_store (optional) and local_load one after another. This is useful -// for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers +// Check that we place local_alloc, local_store (optional) and local_load right after definition of their operands +// in cases where local_alloc is in the loop but it's operand is not. +// This is useful for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers // throughout the computation. -// CHECK-LABEL: order_load_alloc_local_load -// CHECK: %[[LOAD:.+]] = tt.load -// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]] -// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @order_load_alloc_local_load(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { - %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> - tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> + +// CHECK-LABEL: hoist_q_out_of_the_loop +// CHECK: %[[TRUNCF:.+]] = arith.truncf +// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]] +// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] +// CHECK: scf.for +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> + %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> + %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> + %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> + %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { + %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> + %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> + %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> + %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> + %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %107 = arith.addi %arg26, %c128_i64 : i64 + scf.yield %107 : i64 + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} tt.return } } + + +// ----- +// Check that reordering described in hoist_q_out_of_the_loop is not done in the case where both +// local_alloc and it's src tensor defining op are in the loop. +// CHECK-LABEL: no_hoist_q_type_reordering +// CHECK: scf.for +// CHECK: %[[TRUNCF:.+]] = arith.truncf +// CHECK-NEXT: arith.constant +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i64 = arith.constant 0 : i64 + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> + %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> + %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> + %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> + %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { + %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> + %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> + %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> + %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> + %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> + %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %107 = arith.addi %arg26, %c128_i64 : i64 + scf.yield %107 : i64 + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} + tt.return + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> + // CHECK-LABEL: order_load_alloc_local_load_local_store // CHECK: %[[LOAD:.+]] = tt.load // CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 1f9be3a4287f..e122f15fd901 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -61,6 +61,14 @@ findEarlyInsertionPoint(Block *block, Operation *move) { return ipnt; } +// Check if the operation opInsideLoop is inside any scf::ForOp and +// opOutsideLoop is not inside the same loop. +bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { + scf::ForOp parentForOp = opInsideLoop->getParentOfType(); + return parentForOp && !parentForOp->isAncestor(opOutsideLoop); +} + class TritonAMDGPUReorderInstructionsPass : public TritonAMDGPUReorderInstructionsBase< TritonAMDGPUReorderInstructionsPass> { @@ -101,19 +109,28 @@ class TritonAMDGPUReorderInstructionsPass kv.first->moveBefore(kv.second); opToMove.clear(); - // Move writing to LDS and reading from LDS right after the loading of a - // tensor from global memory. There are 2 possible patterns depending on - // whether writing to LDS is done using an optional local_alloc argument or - // a local_store instruction: + // Adjust the placement of LDS writes and reads to immediately follow the + // definition of their operands in case where LDS write is in the + // loop but it's operand is not. This is a heuristic for optimizing fused + // attention by hoisting Q tensor LDS read/write operations outside of the + // loop, as Q is a loop invariant and can be loaded once before entering the + // loop. + // There are two possible patterns for this adjustment depending on + // whether the write to LDS is performed using an optional `local_alloc` + // argument or a `local_store` instruction. + // + // clang-format off // - // 1) %1 = load %ptr + // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading) // %2 = local_alloc %1 // %3 = local_load %2 // - // 2) %1 = load %ptr + // 2) %1 = some_op ... // %2 = local_alloc // %3 = local_store %1, %2 // %4 = local_load %2 + // + // clang-format on m.walk([&](ttg::LocalLoadOp localLoad) { auto localAlloc = localLoad.getSrc().getDefiningOp(); if (!localAlloc) @@ -123,10 +140,15 @@ class TritonAMDGPUReorderInstructionsPass if (localAlloc->getNumOperands() == 1) { if (!localAlloc->hasOneUse()) return; - auto loadOp = localAlloc->getOperand(0).getDefiningOp(); - if (!loadOp) + + auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp(); + // Check if localAlloc is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) { return; - localAlloc->moveAfter(loadOp); + } + + localAlloc->moveAfter(srcTensorOp); localLoad->moveAfter(localAlloc); return; } @@ -145,10 +167,14 @@ class TritonAMDGPUReorderInstructionsPass if (!isa(localStore)) return; - auto loadOp = localStore->getOperand(0).getDefiningOp(); - if (!loadOp) + auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); + // Check if localStore is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { return; - localAlloc->moveAfter(loadOp); + } + + localAlloc->moveAfter(srcTensorOp); localStore->moveAfter(localAlloc); localLoad->moveAfter(localStore); });