From e2c4ef96a5def1dbc1766f6857d2deddefcf35d6 Mon Sep 17 00:00:00 2001 From: Neil Dhar Date: Mon, 11 May 2026 08:21:29 -0700 Subject: [PATCH] Handle region control flow in remat cost calculation Determining whether an op is used outside of the slice being rematerialised is currently conservative. For instance, if an op is used by an `scf.yield`, we automatically treat that as a non-slice user, even if the value that the yield flows into is actually part of the slice or completely unused. To address this, trace through region control flow when deciding whether an op is single use. --- .../Transforms/RemoveLayoutConversions.cpp | 42 ++++- test/TritonGPU/combine.mlir | 170 ++++++++++++++++++ 2 files changed, 207 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 648aa299a641..9e0e6677e0bd 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -1071,6 +1071,24 @@ bool isRematBeneficial(ConvertLayoutOp convertOp, const SetVector &slice, auto *user = use.getOwner(); if (user == convertOp || sliceOps.contains(user)) continue; + // For region branch ops, check whether the values they flow into are in + // the slice or unused instead. + if (isa(user)) + user = user->getParentOp(); + if (auto rbi = dyn_cast(user)) { + RegionBranchSuccessorMapping mapping; + rbi.getSuccessorOperandInputMapping(mapping); + auto it = mapping.find(&use); + if (it != mapping.end()) { + // We have found the values this use flows into, check if they are + // used outside the slice. + bool isSliceOnly = llvm::all_of(it->second, [&](Value v) { + return slice.contains(v) || v.use_empty(); + }); + if (isSliceOnly) + continue; + } + } nonSliceOnlyValues.insert(v); break; } @@ -1079,12 +1097,26 @@ bool isRematBeneficial(ConvertLayoutOp convertOp, const SetVector &slice, // Expand the set to all transitive operands in the slice. for (size_t i = 0; i < nonSliceOnlyValues.size(); ++i) { Value v = nonSliceOnlyValues[i]; - if (auto *op = v.getDefiningOp()) { - for (auto operand : op->getOperands()) - if (slice.contains(operand)) - nonSliceOnlyValues.insert(operand); + auto *op = v.getDefiningOp(); + // If the operand is a block argument, get the enclosing op. + op = op ? op : v.getParentBlock()->getParentOp(); + if (auto rbi = dyn_cast(op)) { + // Try to determine the operands that flow into this value, and mark them + // as being used outside the slice. + RegionBranchInverseSuccessorMapping mapping; + rbi.getSuccessorInputOperandMapping(mapping); + auto it = mapping.find(v); + if (it != mapping.end()) { + for (auto tiedOperand : it->second) + if (slice.contains(tiedOperand->get())) + nonSliceOnlyValues.insert(tiedOperand->get()); + continue; + } } - // TODO: Handle block arguments. + // In the general case, propagate to all operands of the op. + for (auto operand : op->getOperands()) + if (slice.contains(operand)) + nonSliceOnlyValues.insert(operand); } int64_t convertLayoutCost = diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index ca43e0e31a92..02b508a1fe7a 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -4303,3 +4303,173 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, tt.return %sum_cvt : tensor<2xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>> } } + +// ----- + +// CHECK-LABEL: remat_cycle_single_use +// CHECK-NOT: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @remat_cycle_single_use(%arg0: !tt.ptr, %arg1: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<8x8xf32, #blocked> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked> + %1 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %cst) -> (tensor<8x8xf32, #blocked>) : i32 { + %2 = tt.load %0 : tensor<8x8x!tt.ptr, #blocked> + %3 = math.exp %2 : tensor<8x8xf32, #blocked> + %4 = math.exp %3 : tensor<8x8xf32, #blocked> + %5 = math.exp %4 : tensor<8x8xf32, #blocked> + %6 = math.exp %5 : tensor<8x8xf32, #blocked> + %7 = math.exp %6 : tensor<8x8xf32, #blocked> + %8 = arith.addf %arg3, %7 : tensor<8x8xf32, #blocked> + %9 = ttg.convert_layout %8 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1> + "use"(%9) : (tensor<8x8xf32, #blocked1>) -> () + scf.yield %8 : tensor<8x8xf32, #blocked> + } + tt.return + } +} + +// ----- + +// CHECK-LABEL: remat_for_iter_arg +// CHECK-NOT: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @remat_for_iter_arg(%arg0: !tt.ptr, %arg1: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %ptrs = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked> + %0 = tt.load %ptrs : tensor<8x8x!tt.ptr, #blocked> + %1 = math.exp %0 : tensor<8x8xf32, #blocked> + %2 = math.exp %1 : tensor<8x8xf32, #blocked> + %3 = math.exp %2 : tensor<8x8xf32, #blocked> + %4 = math.exp %3 : tensor<8x8xf32, #blocked> + %5 = math.exp %4 : tensor<8x8xf32, #blocked> + %6 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %5) -> (tensor<8x8xf32, #blocked>) : i32 { + %7 = ttg.convert_layout %arg3 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1> + "use"(%7) : (tensor<8x8xf32, #blocked1>) -> () + scf.yield %arg3 : tensor<8x8xf32, #blocked> + } + tt.return + } +} + +// ----- + +// CHECK-LABEL: remat_if_yield +// CHECK-NOT: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @remat_if_yield(%arg0: !tt.ptr, %cond: i1) -> tensor<8x8xf32, #blocked1> { + %0 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked> + %1 = tt.load %0 : tensor<8x8x!tt.ptr, #blocked> + %2 = math.exp %1 : tensor<8x8xf32, #blocked> + %3 = math.exp %2 : tensor<8x8xf32, #blocked> + %4 = math.exp %3 : tensor<8x8xf32, #blocked> + %5 = math.exp %4 : tensor<8x8xf32, #blocked> + %6 = math.exp %5 : tensor<8x8xf32, #blocked> + %7 = scf.if %cond -> tensor<8x8xf32, #blocked> { + scf.yield %6 : tensor<8x8xf32, #blocked> + } else { + scf.yield %6 : tensor<8x8xf32, #blocked> + } + %8 = ttg.convert_layout %7 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1> + tt.return %8: tensor<8x8xf32, #blocked1> + } +} + +// ----- + +// CHECK-LABEL: remat_if_nested_yield +// CHECK-NOT: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @remat_if_nested_yield(%arg0: !tt.ptr, %cond1: i1, %cond2: i1) -> tensor<8x8xf32, #blocked1> { + %0 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked> + %load = tt.load %0 : tensor<8x8x!tt.ptr, #blocked> + %outer = scf.if %cond1 -> tensor<8x8xf32, #blocked> { + %inner = scf.if %cond2 -> tensor<8x8xf32, #blocked> { + %1 = math.exp %load : tensor<8x8xf32, #blocked> + %2 = math.exp %1 : tensor<8x8xf32, #blocked> + %3 = math.exp %2 : tensor<8x8xf32, #blocked> + %4 = math.exp %3 : tensor<8x8xf32, #blocked> + %5 = math.exp %4 : tensor<8x8xf32, #blocked> + scf.yield %5 : tensor<8x8xf32, #blocked> + } else { + scf.yield %load : tensor<8x8xf32, #blocked> + } + scf.yield %inner : tensor<8x8xf32, #blocked> + } else { + scf.yield %load : tensor<8x8xf32, #blocked> + } + %cvt = ttg.convert_layout %outer : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1> + tt.return %cvt : tensor<8x8xf32, #blocked1> + } +} + +// ----- + +// Test that when the result of an IfOp is used outside the slice being +// rematerialized, we are able to propagate that information back to the yield +// operands. This prevents eliminating the convert_layout, because the cost is +// too high. + +// CHECK-LABEL: remat_if_yield_in_branch_multi_use_negative +// CHECK: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @remat_if_yield_in_branch_multi_use_negative(%arg0: !tt.ptr, %cond: i1) -> (tensor<8x8xf32, #blocked>, tensor<8x8xf32, #blocked1>) { + %0 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked> + %load = tt.load %0 : tensor<8x8x!tt.ptr, #blocked> + %result = scf.if %cond -> tensor<8x8xf32, #blocked> { + %1 = math.exp %load : tensor<8x8xf32, #blocked> + %2 = math.exp %1 : tensor<8x8xf32, #blocked> + %3 = math.exp %2 : tensor<8x8xf32, #blocked> + %4 = math.exp %3 : tensor<8x8xf32, #blocked> + %5 = math.exp %4 : tensor<8x8xf32, #blocked> + scf.yield %5 : tensor<8x8xf32, #blocked> + } else { + scf.yield %load : tensor<8x8xf32, #blocked> + } + %cvt = ttg.convert_layout %result : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1> + tt.return %result, %cvt: tensor<8x8xf32, #blocked>, tensor<8x8xf32, #blocked1> + } +} + +// ----- + +// Test that when the block arg is used outside the slice being rematerialized, +// we are able to propagate that information back to the loop operands. This +// prevents eliminating the convert_layout, because the cost is too high. + +// CHECK-LABEL: remat_for_iter_arg_multi_use_negative +// CHECK: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @remat_for_iter_arg_multi_use_negative(%arg0: !tt.ptr, %arg1: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %ptrs = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr, #blocked> + %0 = tt.load %ptrs : tensor<8x8x!tt.ptr, #blocked> + %1 = math.exp %0 : tensor<8x8xf32, #blocked> + %2 = math.exp %1 : tensor<8x8xf32, #blocked> + %3 = math.exp %2 : tensor<8x8xf32, #blocked> + %4 = math.exp %3 : tensor<8x8xf32, #blocked> + %5 = math.exp %4 : tensor<8x8xf32, #blocked> + %6 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %5) -> (tensor<8x8xf32, #blocked>) : i32 { + %7 = ttg.convert_layout %arg3 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1> + "use"(%7) : (tensor<8x8xf32, #blocked1>) -> () + "other_use"(%arg3) : (tensor<8x8xf32, #blocked>) -> () + scf.yield %arg3 : tensor<8x8xf32, #blocked> + } + tt.return + } +}