diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index efe0b890dc3a..3531c0bf6d9d 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -127,7 +127,7 @@ class LayoutRematerialization { } void cleanup(); - void backwardRematerialization(); + bool backwardRematerialization(); void backwardRematerialization(ConvertLayoutOp convertOp); // TODO: Merge the three hoistConvert*(); functions as they are duplicate code void hoistConvertDotOperand(); @@ -1019,7 +1019,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice( return success(); } -void LayoutRematerialization::backwardRematerialization() { +bool LayoutRematerialization::backwardRematerialization() { + bool changed = false; // Go through each ConvertLayoutOp. SmallVector convertOps; funcOp.walk( @@ -1031,8 +1032,11 @@ void LayoutRematerialization::backwardRematerialization() { // backward slices. addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), convertOp.getResult()); + } else { + changed = true; } } + return changed; } void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { @@ -1593,12 +1597,14 @@ void LayoutRematerialization::hoistConvertIntoConditionals( rewriteSlice(slice, layout, convertOp, mapping); } -void backwardRematerialization(ModuleOp module) { - module.walk([](FuncOp funcOp) { +bool backwardRematerialization(ModuleOp module) { + bool changed = false; + module.walk([&](FuncOp funcOp) { LayoutRematerialization layoutRemat(funcOp); - layoutRemat.backwardRematerialization(); + changed |= layoutRemat.backwardRematerialization(); layoutRemat.cleanup(); }); + return changed; } void hoistConvert(ModuleOp module) { @@ -1659,17 +1665,20 @@ class TritonGPURemoveLayoutConversionsPass cleanupConvertOps(); - // 2. For remaining convert ops, try to rematerialize the slice of producer - // operation to avoid having to convert. - backwardRematerialization(m); - LLVM_DEBUG({ - DBGS() << "Module after backward remat:\n"; - m.dump(); - }); - - // Cleanup dummy converts created during backward remat. - cleanupConvertOps(); - + bool changed = false; + do { + changed = false; + // 2. For remaining convert ops, try to rematerialize the slice of + // producer operation to avoid having to convert. + changed = backwardRematerialization(m); + LLVM_DEBUG({ + DBGS() << "Module after backward remat:\n"; + m.dump(); + }); + + // Cleanup dummy converts created during backward remat. + cleanupConvertOps(); + } while (changed); // 3. For remaining converts, try to hoist them above cast generating larger // size types in order to reduce the cost of the convert op. hoistConvert(m); diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 997354685f68..5421fa8d19f5 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2500,11 +2500,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %2 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> %3 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> - // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) - // FIXME: The optimal number of conversions should be 4. - // CHECK-COUNT-5: convert_layout + // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-COUNT-4: convert_layout // CHECK-NOT: convert_layout - // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> // CHECK: } // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 {