Skip to content

Commit

Permalink
Add folding for BroadcastOp along with a test.
Browse files Browse the repository at this point in the history
  • Loading branch information
uazizTT committed Jan 14, 2025
1 parent f11ee52 commit 026b731
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 58 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,7 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
}];

let hasVerifier = 1;
let hasFolder = 1;
}

def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,14 @@ ::mlir::LogicalResult mlir::tt::ttir::BroadcastOp::verify() {
return success();
}

// Fold an identity broadcast: when the result type is exactly the input
// type the op performs no expansion, so it can be replaced by its operand.
// Any non-identity broadcast is left for lowering (returns a null result).
::mlir::OpFoldResult mlir::tt::ttir::BroadcastOp::fold(FoldAdaptor adaptor) {
  auto input = getOperand(0);
  return input.getType() == getType() ? ::mlir::OpFoldResult(input)
                                      : ::mlir::OpFoldResult();
}

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
Expand Down
69 changes: 40 additions & 29 deletions test/ttmlir/Dialect/TTNN/simple_repeat.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,47 @@ module {
}

module {
func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x23x40x128xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%2 = tensor.empty() : tensor<1x1x1x128xf32>
%3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
%4 = tensor.empty() : tensor<1x23x40x128xf32>
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%6 = tensor.empty() : tensor<1x23x40x128xf32>
%7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
return %7 : tensor<1x23x40x128xf32>
}
// Both broadcasts here change the operand type (1x23x40x1 -> 1x23x40x128 and
// 1x1x1x128 -> 1x23x40x128), so neither is an identity and neither folds away;
// they must still lower to ttnn.reshape / ttnn.repeat.
func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x23x40x128xf32>
// Expand the last dimension 1 -> 128.
%1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%2 = tensor.empty() : tensor<1x1x1x128xf32>
%3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
%4 = tensor.empty() : tensor<1x23x40x128xf32>
// Expand dims 1 and 2 (23 and 40) of the reshaped bias-like operand.
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%6 = tensor.empty() : tensor<1x23x40x128xf32>
%7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
return %7 : tensor<1x23x40x128xf32>
}
}

module {
func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x6x2xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
%2 = tensor.empty() : tensor<1x6x1x2xf32>
%3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
%4 = tensor.empty() : tensor<400x6x1x2xf32>
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
%6 = tensor.empty() : tensor<2400x1x2xf32>
%7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
%8 = tensor.empty() : tensor<2400x2xf32>
%9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
return %9 : tensor<2400x2xf32>
}
// A single non-identity broadcast (dim 0 expanded 1 -> 400) sandwiched between
// reshapes; the broadcast changes the type so it is not folded and lowers to
// ttnn.repeat with repeat factor 400 on the leading dimension.
func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x6x2xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
%2 = tensor.empty() : tensor<1x6x1x2xf32>
%3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
%4 = tensor.empty() : tensor<400x6x1x2xf32>
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
%6 = tensor.empty() : tensor<2400x1x2xf32>
// Collapse 400x6 into the single leading dimension 2400.
%7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
%8 = tensor.empty() : tensor<2400x2xf32>
%9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
return %9 : tensor<2400x2xf32>
}
}

// New test for the BroadcastOp folder: broadcast_dimensions are all 1 and the
// input type equals the result type, so the broadcast is an identity. The fold
// replaces it with its operand and no ttnn.repeat is emitted.
module {
func.func public @main(%arg0: tensor<512x512xf32>, %arg1: tensor<512x512xf32>) -> (tensor<512x512xf32>) {
// CHECK-NOT: "ttnn.repeat"
%1 = tensor.empty() : tensor<512x512xf32>
%2 = "ttir.broadcast"(%arg0, %1) <{broadcast_dimensions = array<i32 : 1, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
%3 = tensor.empty() : tensor<512x512xf32>
%4 = "ttir.maximum"(%3, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
return %4 : tensor<512x512xf32>
}
}
69 changes: 40 additions & 29 deletions test/ttmlir/Silicon/TTNN/simple_repeat.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,47 @@ module {
}

module {
func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x23x40x128xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%2 = tensor.empty() : tensor<1x1x1x128xf32>
%3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
%4 = tensor.empty() : tensor<1x23x40x128xf32>
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%6 = tensor.empty() : tensor<1x23x40x128xf32>
%7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
return %7 : tensor<1x23x40x128xf32>
}
// Silicon copy of the TTNN test: both broadcasts change the operand type
// (1x23x40x1 -> 1x23x40x128 and 1x1x1x128 -> 1x23x40x128), so neither is an
// identity and neither folds away; they lower to ttnn.reshape / ttnn.repeat.
func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x23x40x128xf32>
// Expand the last dimension 1 -> 128.
%1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%2 = tensor.empty() : tensor<1x1x1x128xf32>
%3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
%4 = tensor.empty() : tensor<1x23x40x128xf32>
// Expand dims 1 and 2 (23 and 40) of the reshaped operand.
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%6 = tensor.empty() : tensor<1x23x40x128xf32>
%7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
return %7 : tensor<1x23x40x128xf32>
}
}

module {
func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x6x2xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
%2 = tensor.empty() : tensor<1x6x1x2xf32>
%3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
%4 = tensor.empty() : tensor<400x6x1x2xf32>
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
%6 = tensor.empty() : tensor<2400x1x2xf32>
%7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
%8 = tensor.empty() : tensor<2400x2xf32>
%9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
return %9 : tensor<2400x2xf32>
}
// Silicon copy: one non-identity broadcast (dim 0 expanded 1 -> 400) between
// reshapes; the type changes, so it is not folded and lowers to ttnn.repeat
// with repeat factor 400 on the leading dimension.
func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x6x2xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
%2 = tensor.empty() : tensor<1x6x1x2xf32>
%3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
%4 = tensor.empty() : tensor<400x6x1x2xf32>
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
%6 = tensor.empty() : tensor<2400x1x2xf32>
// Collapse 400x6 into the single leading dimension 2400.
%7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
%8 = tensor.empty() : tensor<2400x2xf32>
%9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
return %9 : tensor<2400x2xf32>
}
}

// Silicon copy of the folder test: broadcast_dimensions are all 1 and the
// input type equals the result type, so the broadcast is an identity. The
// fold replaces it with its operand and no ttnn.repeat is emitted.
module {
func.func public @main(%arg0: tensor<512x512xf32>, %arg1: tensor<512x512xf32>) -> (tensor<512x512xf32>) {
// CHECK-NOT: "ttnn.repeat"
%1 = tensor.empty() : tensor<512x512xf32>
%2 = "ttir.broadcast"(%arg0, %1) <{broadcast_dimensions = array<i32 : 1, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
%3 = tensor.empty() : tensor<512x512xf32>
%4 = "ttir.maximum"(%3, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
return %4 : tensor<512x512xf32>
}
}

0 comments on commit 026b731

Please sign in to comment.