Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class LayoutRematerialization {
}

void cleanup();
void backwardRematerialization();
bool backwardRematerialization();
void backwardRematerialization(ConvertLayoutOp convertOp);
// TODO: Merge the three hoistConvert*(); functions as they are duplicate code
void hoistConvertDotOperand();
Expand Down Expand Up @@ -1019,7 +1019,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
return success();
}

void LayoutRematerialization::backwardRematerialization() {
bool LayoutRematerialization::backwardRematerialization() {
bool changed = false;
// Go through each ConvertLayoutOp.
SmallVector<ConvertLayoutOp> convertOps;
funcOp.walk(
Expand All @@ -1031,8 +1032,11 @@ void LayoutRematerialization::backwardRematerialization() {
// backward slices.
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
convertOp.getResult());
} else {
changed = true;
}
}
return changed;
}

void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
Expand Down Expand Up @@ -1593,12 +1597,14 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
rewriteSlice(slice, layout, convertOp, mapping);
}

void backwardRematerialization(ModuleOp module) {
module.walk([](FuncOp funcOp) {
bool backwardRematerialization(ModuleOp module) {
bool changed = false;
module.walk([&](FuncOp funcOp) {
LayoutRematerialization layoutRemat(funcOp);
layoutRemat.backwardRematerialization();
changed |= layoutRemat.backwardRematerialization();
layoutRemat.cleanup();
});
return changed;
}

void hoistConvert(ModuleOp module) {
Expand Down Expand Up @@ -1659,17 +1665,20 @@ class TritonGPURemoveLayoutConversionsPass

cleanupConvertOps();

// 2. For remaining convert ops, try to rematerialize the slice of producer
// operation to avoid having to convert.
backwardRematerialization(m);
LLVM_DEBUG({
DBGS() << "Module after backward remat:\n";
m.dump();
});

// Cleanup dummy converts created during backward remat.
cleanupConvertOps();

bool changed = false;
do {
changed = false;
// 2. For remaining convert ops, try to rematerialize the slice of
// producer operation to avoid having to convert.
changed = backwardRematerialization(m);
LLVM_DEBUG({
DBGS() << "Module after backward remat:\n";
m.dump();
});

// Cleanup dummy converts created during backward remat.
cleanupConvertOps();
} while (changed);
// 3. For remaining converts, try to hoist them above cast generating larger
// size types in order to reduce the cost of the convert op.
hoistConvert(m);
Expand Down
7 changes: 3 additions & 4 deletions test/TritonGPU/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2500,11 +2500,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked2>
%3 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
%4 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
// CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>)
// FIXME: The optimal number of conversions should be 4.
// CHECK-COUNT-5: convert_layout
// CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well this is a good sign I guess

// CHECK-COUNT-4: convert_layout
// CHECK-NOT: convert_layout
// CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
// CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
// CHECK: }
// CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
%5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 {
Expand Down
Loading