Patch summary (reviewer notes; text before the first "diff --git" header is
not part of any hunk):

  * lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp,
    isRematBeneficial:
      - A use by an op that casts to `rbi` (after stepping from a user that
        matches the isa check to its parent op) no longer marks the value as
        used outside the slice when every value that use flows into is in the
        slice or has no uses.
      - When expanding nonSliceOnlyValues, a value with no defining op (a
        block argument) is resolved to the op enclosing its block. Values
        found in the inverse successor mapping propagate only to the operands
        that flow into them; all other values propagate to every operand of
        the op. This replaces the old "TODO: Handle block arguments".
  * test/TritonGPU/combine.mlir: adds six cases, four expecting no
    ttg.convert_layout (CHECK-NOT) and two "_negative" cases expecting it to
    remain (CHECK).

NOTE(review): the type arguments of isa(...)/dyn_cast(...), of SetVector in
the hunk headers, and the pointee type of !tt.ptr are absent from this copy,
presumably lost in extraction; confirm against the original commit before
applying. They are left exactly as found.
NOTE(review): line breaks and indentation were reconstructed from the +/-
markers and the hunk line counts; leading whitespace is a best guess.

diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index 648aa299a641..9e0e6677e0bd 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -1071,6 +1071,24 @@ bool isRematBeneficial(ConvertLayoutOp convertOp, const SetVector &slice,
       auto *user = use.getOwner();
       if (user == convertOp || sliceOps.contains(user))
         continue;
+      // For region branch ops, check whether the values they flow into are in
+      // the slice or unused instead.
+      if (isa(user))
+        user = user->getParentOp();
+      if (auto rbi = dyn_cast(user)) {
+        RegionBranchSuccessorMapping mapping;
+        rbi.getSuccessorOperandInputMapping(mapping);
+        auto it = mapping.find(&use);
+        if (it != mapping.end()) {
+          // We have found the values this use flows into, check if they are
+          // used outside the slice.
+          bool isSliceOnly = llvm::all_of(it->second, [&](Value v) {
+            return slice.contains(v) || v.use_empty();
+          });
+          if (isSliceOnly)
+            continue;
+        }
+      }
       nonSliceOnlyValues.insert(v);
       break;
     }
@@ -1079,12 +1097,26 @@ bool isRematBeneficial(ConvertLayoutOp convertOp, const SetVector &slice,
   // Expand the set to all transitive operands in the slice.
   for (size_t i = 0; i < nonSliceOnlyValues.size(); ++i) {
     Value v = nonSliceOnlyValues[i];
-    if (auto *op = v.getDefiningOp()) {
-      for (auto operand : op->getOperands())
-        if (slice.contains(operand))
-          nonSliceOnlyValues.insert(operand);
+    auto *op = v.getDefiningOp();
+    // If the operand is a block argument, get the enclosing op.
+    op = op ? op : v.getParentBlock()->getParentOp();
+    if (auto rbi = dyn_cast(op)) {
+      // Try to determine the operands that flow into this value, and mark them
+      // as being used outside the slice.
+      RegionBranchInverseSuccessorMapping mapping;
+      rbi.getSuccessorInputOperandMapping(mapping);
+      auto it = mapping.find(v);
+      if (it != mapping.end()) {
+        for (auto tiedOperand : it->second)
+          if (slice.contains(tiedOperand->get()))
+            nonSliceOnlyValues.insert(tiedOperand->get());
+        continue;
+      }
     }
-    // TODO: Handle block arguments.
+    // In the general case, propagate to all operands of the op.
+    for (auto operand : op->getOperands())
+      if (slice.contains(operand))
+        nonSliceOnlyValues.insert(operand);
   }
 
   int64_t convertLayoutCost =
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index ca43e0e31a92..02b508a1fe7a 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -4303,3 +4303,173 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32,
     tt.return %sum_cvt : tensor<2xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
   }
 }
+
+// -----
+
+// CHECK-LABEL: remat_cycle_single_use
+// CHECK-NOT: ttg.convert_layout
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @remat_cycle_single_use(%arg0: !tt.ptr, %arg1: i32) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<8x8xf32, #blocked>
+    %0 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked>
+    %1 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %cst) -> (tensor<8x8xf32, #blocked>) : i32 {
+      %2 = tt.load %0 : tensor<8x8x!tt.ptr, #blocked>
+      %3 = math.exp %2 : tensor<8x8xf32, #blocked>
+      %4 = math.exp %3 : tensor<8x8xf32, #blocked>
+      %5 = math.exp %4 : tensor<8x8xf32, #blocked>
+      %6 = math.exp %5 : tensor<8x8xf32, #blocked>
+      %7 = math.exp %6 : tensor<8x8xf32, #blocked>
+      %8 = arith.addf %arg3, %7 : tensor<8x8xf32, #blocked>
+      %9 = ttg.convert_layout %8 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
+      "use"(%9) : (tensor<8x8xf32, #blocked1>) -> ()
+      scf.yield %8 : tensor<8x8xf32, #blocked>
+    }
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: remat_for_iter_arg
+// CHECK-NOT: ttg.convert_layout
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @remat_for_iter_arg(%arg0: !tt.ptr, %arg1: i32) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %ptrs = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked>
+    %0 = tt.load %ptrs : tensor<8x8x!tt.ptr, #blocked>
+    %1 = math.exp %0 : tensor<8x8xf32, #blocked>
+    %2 = math.exp %1 : tensor<8x8xf32, #blocked>
+    %3 = math.exp %2 : tensor<8x8xf32, #blocked>
+    %4 = math.exp %3 : tensor<8x8xf32, #blocked>
+    %5 = math.exp %4 : tensor<8x8xf32, #blocked>
+    %6 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %5) -> (tensor<8x8xf32, #blocked>) : i32 {
+      %7 = ttg.convert_layout %arg3 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
+      "use"(%7) : (tensor<8x8xf32, #blocked1>) -> ()
+      scf.yield %arg3 : tensor<8x8xf32, #blocked>
+    }
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: remat_if_yield
+// CHECK-NOT: ttg.convert_layout
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @remat_if_yield(%arg0: !tt.ptr, %cond: i1) -> tensor<8x8xf32, #blocked1> {
+    %0 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked>
+    %1 = tt.load %0 : tensor<8x8x!tt.ptr, #blocked>
+    %2 = math.exp %1 : tensor<8x8xf32, #blocked>
+    %3 = math.exp %2 : tensor<8x8xf32, #blocked>
+    %4 = math.exp %3 : tensor<8x8xf32, #blocked>
+    %5 = math.exp %4 : tensor<8x8xf32, #blocked>
+    %6 = math.exp %5 : tensor<8x8xf32, #blocked>
+    %7 = scf.if %cond -> tensor<8x8xf32, #blocked> {
+      scf.yield %6 : tensor<8x8xf32, #blocked>
+    } else {
+      scf.yield %6 : tensor<8x8xf32, #blocked>
+    }
+    %8 = ttg.convert_layout %7 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
+    tt.return %8: tensor<8x8xf32, #blocked1>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: remat_if_nested_yield
+// CHECK-NOT: ttg.convert_layout
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @remat_if_nested_yield(%arg0: !tt.ptr, %cond1: i1, %cond2: i1) -> tensor<8x8xf32, #blocked1> {
+    %0 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked>
+    %load = tt.load %0 : tensor<8x8x!tt.ptr, #blocked>
+    %outer = scf.if %cond1 -> tensor<8x8xf32, #blocked> {
+      %inner = scf.if %cond2 -> tensor<8x8xf32, #blocked> {
+        %1 = math.exp %load : tensor<8x8xf32, #blocked>
+        %2 = math.exp %1 : tensor<8x8xf32, #blocked>
+        %3 = math.exp %2 : tensor<8x8xf32, #blocked>
+        %4 = math.exp %3 : tensor<8x8xf32, #blocked>
+        %5 = math.exp %4 : tensor<8x8xf32, #blocked>
+        scf.yield %5 : tensor<8x8xf32, #blocked>
+      } else {
+        scf.yield %load : tensor<8x8xf32, #blocked>
+      }
+      scf.yield %inner : tensor<8x8xf32, #blocked>
+    } else {
+      scf.yield %load : tensor<8x8xf32, #blocked>
+    }
+    %cvt = ttg.convert_layout %outer : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
+    tt.return %cvt : tensor<8x8xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test that when the result of an IfOp is used outside the slice being
+// rematerialized, we are able to propagate that information back to the yield
+// operands. This prevents eliminating the convert_layout, because the cost is
+// too high.
+
+// CHECK-LABEL: remat_if_yield_in_branch_multi_use_negative
+// CHECK: ttg.convert_layout
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @remat_if_yield_in_branch_multi_use_negative(%arg0: !tt.ptr, %cond: i1) -> (tensor<8x8xf32, #blocked>, tensor<8x8xf32, #blocked1>) {
+    %0 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked>
+    %load = tt.load %0 : tensor<8x8x!tt.ptr, #blocked>
+    %result = scf.if %cond -> tensor<8x8xf32, #blocked> {
+      %1 = math.exp %load : tensor<8x8xf32, #blocked>
+      %2 = math.exp %1 : tensor<8x8xf32, #blocked>
+      %3 = math.exp %2 : tensor<8x8xf32, #blocked>
+      %4 = math.exp %3 : tensor<8x8xf32, #blocked>
+      %5 = math.exp %4 : tensor<8x8xf32, #blocked>
+      scf.yield %5 : tensor<8x8xf32, #blocked>
+    } else {
+      scf.yield %load : tensor<8x8xf32, #blocked>
+    }
+    %cvt = ttg.convert_layout %result : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
+    tt.return %result, %cvt: tensor<8x8xf32, #blocked>, tensor<8x8xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test that when the block arg is used outside the slice being rematerialized,
+// we are able to propagate that information back to the loop operands. This
+// prevents eliminating the convert_layout, because the cost is too high.
+
+// CHECK-LABEL: remat_for_iter_arg_multi_use_negative
+// CHECK: ttg.convert_layout
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @remat_for_iter_arg_multi_use_negative(%arg0: !tt.ptr, %arg1: i32) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %ptrs = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked>
+    %0 = tt.load %ptrs : tensor<8x8x!tt.ptr, #blocked>
+    %1 = math.exp %0 : tensor<8x8xf32, #blocked>
+    %2 = math.exp %1 : tensor<8x8xf32, #blocked>
+    %3 = math.exp %2 : tensor<8x8xf32, #blocked>
+    %4 = math.exp %3 : tensor<8x8xf32, #blocked>
+    %5 = math.exp %4 : tensor<8x8xf32, #blocked>
+    %6 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %5) -> (tensor<8x8xf32, #blocked>) : i32 {
+      %7 = ttg.convert_layout %arg3 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
+      "use"(%7) : (tensor<8x8xf32, #blocked1>) -> ()
+      "other_use"(%arg3) : (tensor<8x8xf32, #blocked>) -> ()
+      scf.yield %arg3 : tensor<8x8xf32, #blocked>
+    }
+    tt.return
+  }
+}