Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,24 @@ bool isRematBeneficial(ConvertLayoutOp convertOp, const SetVector<Value> &slice,
auto *user = use.getOwner();
if (user == convertOp || sliceOps.contains(user))
continue;
// For region branch ops, check whether the values they flow into are in
// the slice or unused instead.
if (isa<RegionBranchTerminatorOpInterface>(user))
user = user->getParentOp();
if (auto rbi = dyn_cast<RegionBranchOpInterface>(user)) {
RegionBranchSuccessorMapping mapping;
rbi.getSuccessorOperandInputMapping(mapping);
auto it = mapping.find(&use);
if (it != mapping.end()) {
// We have found the values this use flows into, check if they are
// used outside the slice.
bool isSliceOnly = llvm::all_of(it->second, [&](Value v) {
return slice.contains(v) || v.use_empty();
});
if (isSliceOnly)
continue;
}
}
nonSliceOnlyValues.insert(v);
break;
}
Expand All @@ -1079,12 +1097,26 @@ bool isRematBeneficial(ConvertLayoutOp convertOp, const SetVector<Value> &slice,
// Expand the set to all transitive operands in the slice.
for (size_t i = 0; i < nonSliceOnlyValues.size(); ++i) {
Value v = nonSliceOnlyValues[i];
if (auto *op = v.getDefiningOp()) {
for (auto operand : op->getOperands())
if (slice.contains(operand))
nonSliceOnlyValues.insert(operand);
auto *op = v.getDefiningOp();
// If the operand is a block argument, get the enclosing op.
op = op ? op : v.getParentBlock()->getParentOp();
if (auto rbi = dyn_cast<RegionBranchOpInterface>(op)) {
// Try to determine the operands that flow into this value, and mark them
// as being used outside the slice.
RegionBranchInverseSuccessorMapping mapping;
rbi.getSuccessorInputOperandMapping(mapping);
auto it = mapping.find(v);
if (it != mapping.end()) {
for (auto tiedOperand : it->second)
if (slice.contains(tiedOperand->get()))
nonSliceOnlyValues.insert(tiedOperand->get());
continue;
}
}
// TODO: Handle block arguments.
// In the general case, propagate to all operands of the op.
for (auto operand : op->getOperands())
if (slice.contains(operand))
nonSliceOnlyValues.insert(operand);
}

int64_t convertLayoutCost =
Expand Down
170 changes: 170 additions & 0 deletions test/TritonGPU/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4303,3 +4303,173 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32,
tt.return %sum_cvt : tensor<2xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
}
}

// -----

// CHECK-LABEL: remat_cycle_single_use
// CHECK-NOT: ttg.convert_layout
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @remat_cycle_single_use(%arg0: !tt.ptr<f32>, %arg1: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<8x8xf32, #blocked>
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<8x8x!tt.ptr<f32>, #blocked>
%1 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %cst) -> (tensor<8x8xf32, #blocked>) : i32 {
%2 = tt.load %0 : tensor<8x8x!tt.ptr<f32>, #blocked>
%3 = math.exp %2 : tensor<8x8xf32, #blocked>
%4 = math.exp %3 : tensor<8x8xf32, #blocked>
%5 = math.exp %4 : tensor<8x8xf32, #blocked>
%6 = math.exp %5 : tensor<8x8xf32, #blocked>
%7 = math.exp %6 : tensor<8x8xf32, #blocked>
%8 = arith.addf %arg3, %7 : tensor<8x8xf32, #blocked>
%9 = ttg.convert_layout %8 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
"use"(%9) : (tensor<8x8xf32, #blocked1>) -> ()
scf.yield %8 : tensor<8x8xf32, #blocked>
}
tt.return
}
}

// -----

// CHECK-LABEL: remat_for_iter_arg
// CHECK-NOT: ttg.convert_layout
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @remat_for_iter_arg(%arg0: !tt.ptr<f32>, %arg1: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%ptrs = tt.splat %arg0 : !tt.ptr<f32> -> tensor<8x8x!tt.ptr<f32>, #blocked>
%0 = tt.load %ptrs : tensor<8x8x!tt.ptr<f32>, #blocked>
%1 = math.exp %0 : tensor<8x8xf32, #blocked>
%2 = math.exp %1 : tensor<8x8xf32, #blocked>
%3 = math.exp %2 : tensor<8x8xf32, #blocked>
%4 = math.exp %3 : tensor<8x8xf32, #blocked>
%5 = math.exp %4 : tensor<8x8xf32, #blocked>
%6 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %5) -> (tensor<8x8xf32, #blocked>) : i32 {
%7 = ttg.convert_layout %arg3 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
"use"(%7) : (tensor<8x8xf32, #blocked1>) -> ()
scf.yield %arg3 : tensor<8x8xf32, #blocked>
}
tt.return
}
}

// -----

// CHECK-LABEL: remat_if_yield
// CHECK-NOT: ttg.convert_layout
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @remat_if_yield(%arg0: !tt.ptr<f32>, %cond: i1) -> tensor<8x8xf32, #blocked1> {
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<8x8x!tt.ptr<f32>, #blocked>
%1 = tt.load %0 : tensor<8x8x!tt.ptr<f32>, #blocked>
%2 = math.exp %1 : tensor<8x8xf32, #blocked>
%3 = math.exp %2 : tensor<8x8xf32, #blocked>
%4 = math.exp %3 : tensor<8x8xf32, #blocked>
%5 = math.exp %4 : tensor<8x8xf32, #blocked>
%6 = math.exp %5 : tensor<8x8xf32, #blocked>
%7 = scf.if %cond -> tensor<8x8xf32, #blocked> {
scf.yield %6 : tensor<8x8xf32, #blocked>
} else {
scf.yield %6 : tensor<8x8xf32, #blocked>
}
%8 = ttg.convert_layout %7 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
tt.return %8: tensor<8x8xf32, #blocked1>
}
}

// -----

// CHECK-LABEL: remat_if_nested_yield
// CHECK-NOT: ttg.convert_layout
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @remat_if_nested_yield(%arg0: !tt.ptr<f32>, %cond1: i1, %cond2: i1) -> tensor<8x8xf32, #blocked1> {
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<8x8x!tt.ptr<f32>, #blocked>
%load = tt.load %0 : tensor<8x8x!tt.ptr<f32>, #blocked>
%outer = scf.if %cond1 -> tensor<8x8xf32, #blocked> {
%inner = scf.if %cond2 -> tensor<8x8xf32, #blocked> {
%1 = math.exp %load : tensor<8x8xf32, #blocked>
%2 = math.exp %1 : tensor<8x8xf32, #blocked>
%3 = math.exp %2 : tensor<8x8xf32, #blocked>
%4 = math.exp %3 : tensor<8x8xf32, #blocked>
%5 = math.exp %4 : tensor<8x8xf32, #blocked>
scf.yield %5 : tensor<8x8xf32, #blocked>
} else {
scf.yield %load : tensor<8x8xf32, #blocked>
}
scf.yield %inner : tensor<8x8xf32, #blocked>
} else {
scf.yield %load : tensor<8x8xf32, #blocked>
}
%cvt = ttg.convert_layout %outer : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
tt.return %cvt : tensor<8x8xf32, #blocked1>
}
}

// -----

// Test that when the result of an IfOp is used outside the slice being
// rematerialized, we are able to propagate that information back to the yield
// operands. This prevents eliminating the convert_layout, because the cost is
// too high.

// CHECK-LABEL: remat_if_yield_in_branch_multi_use_negative
// CHECK: ttg.convert_layout
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @remat_if_yield_in_branch_multi_use_negative(%arg0: !tt.ptr<f32>, %cond: i1) -> (tensor<8x8xf32, #blocked>, tensor<8x8xf32, #blocked1>) {
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<8x8x!tt.ptr<f32>, #blocked>
%load = tt.load %0 : tensor<8x8x!tt.ptr<f32>, #blocked>
%result = scf.if %cond -> tensor<8x8xf32, #blocked> {
%1 = math.exp %load : tensor<8x8xf32, #blocked>
%2 = math.exp %1 : tensor<8x8xf32, #blocked>
%3 = math.exp %2 : tensor<8x8xf32, #blocked>
%4 = math.exp %3 : tensor<8x8xf32, #blocked>
%5 = math.exp %4 : tensor<8x8xf32, #blocked>
scf.yield %5 : tensor<8x8xf32, #blocked>
} else {
scf.yield %load : tensor<8x8xf32, #blocked>
}
%cvt = ttg.convert_layout %result : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
tt.return %result, %cvt: tensor<8x8xf32, #blocked>, tensor<8x8xf32, #blocked1>
}
}

// -----

// Test that when the block arg is used outside the slice being rematerialized,
// we are able to propagate that information back to the loop operands. This
// prevents eliminating the convert_layout, because the cost is too high.

// CHECK-LABEL: remat_for_iter_arg_multi_use_negative
// CHECK: ttg.convert_layout
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @remat_for_iter_arg_multi_use_negative(%arg0: !tt.ptr<f32>, %arg1: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%ptrs = tt.splat %arg0 : !tt.ptr<f32> -> tensor<8x8x!tt.ptr<f32>, #blocked>
%0 = tt.load %ptrs : tensor<8x8x!tt.ptr<f32>, #blocked>
%1 = math.exp %0 : tensor<8x8xf32, #blocked>
%2 = math.exp %1 : tensor<8x8xf32, #blocked>
%3 = math.exp %2 : tensor<8x8xf32, #blocked>
%4 = math.exp %3 : tensor<8x8xf32, #blocked>
%5 = math.exp %4 : tensor<8x8xf32, #blocked>
%6 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %5) -> (tensor<8x8xf32, #blocked>) : i32 {
%7 = ttg.convert_layout %arg3 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
"use"(%7) : (tensor<8x8xf32, #blocked1>) -> ()
"other_use"(%arg3) : (tensor<8x8xf32, #blocked>) -> ()
scf.yield %arg3 : tensor<8x8xf32, #blocked>
}
tt.return
}
}
Loading