From 8f7139c851af6dd53f5108a92f686e01f0f13cf0 Mon Sep 17 00:00:00 2001
From: uazizTT
Date: Tue, 14 Jan 2025 11:01:08 -0500
Subject: [PATCH] Refactor BroadcastOp folder and add tests.

---
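Notes:

The BroadcastOp folder body is unchanged by this patch; it only moves so that
it sits next to the existing BroadcastOp verifier in TTIROps.cpp instead of in
a separate BroadcastOp section earlier in the file. The TTIROps.td hunk only
drops a blank line; `hasFolder` stays set. The folder erases a ttir.broadcast
whose broadcast_dimensions entries are all 1, i.e. a broadcast that leaves the
shape unchanged, by replacing the op's result with its input. A minimal sketch
of the two cases, mirroring the tests below (illustrative only, and assuming
the array<i32: ...> attribute form used in those tests):

    // Folds: every broadcast dimension is 1, so uses of %2 are rewritten to
    // use %arg0 directly and no ttnn.repeat is emitted (checked below via
    // CHECK-NOT).
    %2 = "ttir.broadcast"(%arg0, %1) <{broadcast_dimensions = array<i32: 1, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>

    // Does not fold: dimension 0 is repeated 400 times, so the op is kept and
    // later lowers to a ttnn.repeat with shape [400, 1, 1, 1].
    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>

The new 512x512 test case is added to both the Dialect and Silicon variants of
simple_repeat.mlir to cover the folding path; the existing test cases are only
reindented.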
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td   |  1 -
 lib/Dialect/TTIR/IR/TTIROps.cpp             | 24 +++----
 test/ttmlir/Dialect/TTNN/simple_repeat.mlir | 69 ++++++++++++---------
 test/ttmlir/Silicon/TTNN/simple_repeat.mlir | 69 ++++++++++++---------
 4 files changed, 90 insertions(+), 73 deletions(-)

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 710f88cfef..df01d75e1d 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -864,7 +864,6 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
   }];
 
   let hasVerifier = 1;
-
   let hasFolder = 1;
 }
 
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 0b58ed860e..ac7ba6620b 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -61,20 +61,6 @@ void mlir::tt::ttir::BitwiseXorOp::getCanonicalizationPatterns(
       });
 }
 
-//===----------------------------------------------------------------------===//
-// BroadcastOp
-//===----------------------------------------------------------------------===//
-
-// BroadcastOp folder
-::mlir::OpFoldResult mlir::tt::ttir::BroadcastOp::fold(FoldAdaptor adaptor) {
-  // If the input doesn't change the shape, we can fold the operation.
-  if (llvm::all_of(getBroadcastDimensions(),
-                   [](const int32_t dim) { return dim == 1; })) {
-    return getInput();
-  }
-  return {};
-}
-
 //===----------------------------------------------------------------------===//
 // ClampOp
 //===----------------------------------------------------------------------===//
@@ -481,6 +467,16 @@ ::mlir::LogicalResult mlir::tt::ttir::BroadcastOp::verify() {
   return success();
 }
 
+// BroadcastOp folder
+::mlir::OpFoldResult mlir::tt::ttir::BroadcastOp::fold(FoldAdaptor adaptor) {
+  // If the input doesn't change the shape, we can fold the operation.
+  if (llvm::all_of(getBroadcastDimensions(),
+                   [](const int32_t dim) { return dim == 1; })) {
+    return getInput();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // SliceOp
 //===----------------------------------------------------------------------===//
diff --git a/test/ttmlir/Dialect/TTNN/simple_repeat.mlir b/test/ttmlir/Dialect/TTNN/simple_repeat.mlir
index 42261936e5..6309223924 100644
--- a/test/ttmlir/Dialect/TTNN/simple_repeat.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_repeat.mlir
@@ -27,36 +27,47 @@ module {
 }
 
 module {
-    func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
-        // CHECK: %{{[0-9]+}} = "ttnn.reshape"
-        // CHECK: %{{[0-9]+}} = "ttnn.repeat"
-        // CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
-        %0 = tensor.empty() : tensor<1x23x40x128xf32>
-        %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32: 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-        %2 = tensor.empty() : tensor<1x1x1x128xf32>
-        %3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
-        %4 = tensor.empty() : tensor<1x23x40x128xf32>
-        %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-        %6 = tensor.empty() : tensor<1x23x40x128xf32>
-        %7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-        return %7 : tensor<1x23x40x128xf32>
-    }
+  func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
+    // CHECK: %{{[0-9]+}} = "ttnn.reshape"
+    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
+    // CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
+    %0 = tensor.empty() : tensor<1x23x40x128xf32>
+    %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32: 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    %2 = tensor.empty() : tensor<1x1x1x128xf32>
+    %3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
+    %4 = tensor.empty() : tensor<1x23x40x128xf32>
+    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    %6 = tensor.empty() : tensor<1x23x40x128xf32>
+    %7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    return %7 : tensor<1x23x40x128xf32>
   }
+}
 
 module {
-    func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
-        // CHECK: %{{[0-9]+}} = "ttnn.repeat"
-        // CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
-        %0 = tensor.empty() : tensor<1x6x2xf32>
-        %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
-        %2 = tensor.empty() : tensor<1x6x1x2xf32>
-        %3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
-        %4 = tensor.empty() : tensor<400x6x1x2xf32>
-        %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
-        %6 = tensor.empty() : tensor<2400x1x2xf32>
-        %7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
-        %8 = tensor.empty() : tensor<2400x2xf32>
-        %9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
-        return %9 : tensor<2400x2xf32>
-    }
+  func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
+    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
+    // CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
+    %0 = tensor.empty() : tensor<1x6x2xf32>
+    %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
+    %2 = tensor.empty() : tensor<1x6x1x2xf32>
+    %3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
+    %4 = tensor.empty() : tensor<400x6x1x2xf32>
+    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
+    %6 = tensor.empty() : tensor<2400x1x2xf32>
+    %7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
+    %8 = tensor.empty() : tensor<2400x2xf32>
+    %9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
+    return %9 : tensor<2400x2xf32>
   }
+}
+
+module {
+  func.func public @main(%arg0: tensor<512x512xf32>, %arg1: tensor<512x512xf32>) -> (tensor<512x512xf32>) {
+    // CHECK-NOT: "ttnn.repeat"
+    %1 = tensor.empty() : tensor<512x512xf32>
+    %2 = "ttir.broadcast"(%arg0, %1) <{broadcast_dimensions = array<i32: 1, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
+    %3 = tensor.empty() : tensor<512x512xf32>
+    %4 = "ttir.maximum"(%3, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
+    return %4 : tensor<512x512xf32>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/simple_repeat.mlir b/test/ttmlir/Silicon/TTNN/simple_repeat.mlir
index ab91af2ee6..a3e947b2ac 100644
--- a/test/ttmlir/Silicon/TTNN/simple_repeat.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_repeat.mlir
@@ -27,36 +27,47 @@ module {
 }
 
 module {
-    func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
-        // CHECK: %{{[0-9]+}} = "ttnn.reshape"
-        // CHECK: %{{[0-9]+}} = "ttnn.repeat"
-        // CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
-        %0 = tensor.empty() : tensor<1x23x40x128xf32>
-        %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32: 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-        %2 = tensor.empty() : tensor<1x1x1x128xf32>
-        %3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
-        %4 = tensor.empty() : tensor<1x23x40x128xf32>
-        %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-        %6 = tensor.empty() : tensor<1x23x40x128xf32>
-        %7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-        return %7 : tensor<1x23x40x128xf32>
-    }
+  func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
+    // CHECK: %{{[0-9]+}} = "ttnn.reshape"
+    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
+    // CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
+    %0 = tensor.empty() : tensor<1x23x40x128xf32>
+    %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32: 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    %2 = tensor.empty() : tensor<1x1x1x128xf32>
+    %3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
+    %4 = tensor.empty() : tensor<1x23x40x128xf32>
+    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    %6 = tensor.empty() : tensor<1x23x40x128xf32>
+    %7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    return %7 : tensor<1x23x40x128xf32>
   }
+}
+
+module {
+  func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
+    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
+    // CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
+    %0 = tensor.empty() : tensor<1x6x2xf32>
+    %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
+    %2 = tensor.empty() : tensor<1x6x1x2xf32>
+    %3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
+    %4 = tensor.empty() : tensor<400x6x1x2xf32>
+    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
+    %6 = tensor.empty() : tensor<2400x1x2xf32>
+    %7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
+    %8 = tensor.empty() : tensor<2400x2xf32>
+    %9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
+    return %9 : tensor<2400x2xf32>
+  }
+}
 
 module {
-    func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
-        // CHECK: %{{[0-9]+}} = "ttnn.repeat"
-        // CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
-        %0 = tensor.empty() : tensor<1x6x2xf32>
-        %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
-        %2 = tensor.empty() : tensor<1x6x1x2xf32>
-        %3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
-        %4 = tensor.empty() : tensor<400x6x1x2xf32>
-        %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
-        %6 = tensor.empty() : tensor<2400x1x2xf32>
-        %7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
-        %8 = tensor.empty() : tensor<2400x2xf32>
-        %9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
-        return %9 : tensor<2400x2xf32>
-    }
+  func.func public @main(%arg0: tensor<512x512xf32>, %arg1: tensor<512x512xf32>) -> (tensor<512x512xf32>) {
+    // CHECK-NOT: "ttnn.repeat"
+    %1 = tensor.empty() : tensor<512x512xf32>
+    %2 = "ttir.broadcast"(%arg0, %1) <{broadcast_dimensions = array<i32: 1, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
+    %3 = tensor.empty() : tensor<512x512xf32>
+    %4 = "ttir.maximum"(%3, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
+    return %4 : tensor<512x512xf32>
   }
+}