Refactor BroadcastOp folder and add tests.
uazizTT committed Jan 15, 2025
1 parent d129a9e commit 8f7139c
Showing 4 changed files with 90 additions and 73 deletions.
1 change: 0 additions & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -864,7 +864,6 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
   }];
 
   let hasVerifier = 1;
-
   let hasFolder = 1;
 }

24 changes: 10 additions & 14 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -61,20 +61,6 @@ void mlir::tt::ttir::BitwiseXorOp::getCanonicalizationPatterns(
   });
 }
 
-//===----------------------------------------------------------------------===//
-// BroadcastOp
-//===----------------------------------------------------------------------===//
-
-// BroadcastOp folder
-::mlir::OpFoldResult mlir::tt::ttir::BroadcastOp::fold(FoldAdaptor adaptor) {
-  // If the input doesn't change the shape, we can fold the operation.
-  if (llvm::all_of(getBroadcastDimensions(),
-                   [](const int32_t dim) { return dim == 1; })) {
-    return getInput();
-  }
-  return {};
-}
-
 //===----------------------------------------------------------------------===//
 // ClampOp
 //===----------------------------------------------------------------------===//
@@ -481,6 +467,16 @@ ::mlir::LogicalResult mlir::tt::ttir::BroadcastOp::verify() {
   return success();
 }
 
+// BroadcastOp folder
+::mlir::OpFoldResult mlir::tt::ttir::BroadcastOp::fold(FoldAdaptor adaptor) {
+  // If the input doesn't change the shape, we can fold the operation.
+  if (llvm::all_of(getBroadcastDimensions(),
+                   [](const int32_t dim) { return dim == 1; })) {
+    return getInput();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // SliceOp
 //===----------------------------------------------------------------------===//
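The relocated folder implements a simple identity rule: when every entry of broadcast_dimensions is 1, the op adds no elements along any axis, so the result is just the input and all uses can be rewired to it. A minimal sketch of IR the fold fires on, written in the same style as the tests below (the 32x32 shapes are illustrative, not taken from this commit):

// Identity broadcast: every broadcast dimension is 1, so fold() returns
// getInput(), %1 is replaced by %arg0, and %0 becomes dead.
%0 = tensor.empty() : tensor<32x32xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1>}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>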
69 changes: 40 additions & 29 deletions test/ttmlir/Dialect/TTNN/simple_repeat.mlir
@@ -27,36 +27,47 @@ module {
 }
 
 module {
-  func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
-    // CHECK: %{{[0-9]+}} = "ttnn.reshape"
-    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
-    // CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
-    %0 = tensor.empty() : tensor<1x23x40x128xf32>
-    %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-    %2 = tensor.empty() : tensor<1x1x1x128xf32>
-    %3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
-    %4 = tensor.empty() : tensor<1x23x40x128xf32>
-    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-    %6 = tensor.empty() : tensor<1x23x40x128xf32>
-    %7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-    return %7 : tensor<1x23x40x128xf32>
-  }
+  func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
+    // CHECK: %{{[0-9]+}} = "ttnn.reshape"
+    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
+    // CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
+    %0 = tensor.empty() : tensor<1x23x40x128xf32>
+    %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    %2 = tensor.empty() : tensor<1x1x1x128xf32>
+    %3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
+    %4 = tensor.empty() : tensor<1x23x40x128xf32>
+    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    %6 = tensor.empty() : tensor<1x23x40x128xf32>
+    %7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    return %7 : tensor<1x23x40x128xf32>
+  }
 }
 
 module {
-  func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
-    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
-    // CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
-    %0 = tensor.empty() : tensor<1x6x2xf32>
-    %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
-    %2 = tensor.empty() : tensor<1x6x1x2xf32>
-    %3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
-    %4 = tensor.empty() : tensor<400x6x1x2xf32>
-    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
-    %6 = tensor.empty() : tensor<2400x1x2xf32>
-    %7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
-    %8 = tensor.empty() : tensor<2400x2xf32>
-    %9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
-    return %9 : tensor<2400x2xf32>
-  }
+  func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
+    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
+    // CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
+    %0 = tensor.empty() : tensor<1x6x2xf32>
+    %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
+    %2 = tensor.empty() : tensor<1x6x1x2xf32>
+    %3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
+    %4 = tensor.empty() : tensor<400x6x1x2xf32>
+    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
+    %6 = tensor.empty() : tensor<2400x1x2xf32>
+    %7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
+    %8 = tensor.empty() : tensor<2400x2xf32>
+    %9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
+    return %9 : tensor<2400x2xf32>
+  }
 }
+
+module {
+  func.func public @main(%arg0: tensor<512x512xf32>, %arg1: tensor<512x512xf32>) -> (tensor<512x512xf32>) {
+    // CHECK-NOT: "ttnn.repeat"
+    %1 = tensor.empty() : tensor<512x512xf32>
+    %2 = "ttir.broadcast"(%arg0, %1) <{broadcast_dimensions = array<i32 : 1, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
+    %3 = tensor.empty() : tensor<512x512xf32>
+    %4 = "ttir.maximum"(%3, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
+    return %4 : tensor<512x512xf32>
+  }
+}
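The added third module is the negative test for the folder: its broadcast_dimensions are all 1, so the broadcast should fold away before TTNN lowering, and CHECK-NOT verifies that no ttnn.repeat survives. Conceptually, folding rewrites the body along these lines (a sketch of the expected canonicalized TTIR, not output captured from this commit):

// The identity broadcast %2 folds to %arg0, which then feeds
// ttir.maximum directly; the leftover tensor.empty %1 becomes dead code.
%3 = tensor.empty() : tensor<512x512xf32>
%4 = "ttir.maximum"(%3, %arg1, %arg0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
return %4 : tensor<512x512xf32>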
69 changes: 40 additions & 29 deletions test/ttmlir/Silicon/TTNN/simple_repeat.mlir
@@ -27,36 +27,47 @@ module {
 }
 
 module {
-  func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
-    // CHECK: %{{[0-9]+}} = "ttnn.reshape"
-    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
-    // CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
-    %0 = tensor.empty() : tensor<1x23x40x128xf32>
-    %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-    %2 = tensor.empty() : tensor<1x1x1x128xf32>
-    %3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
-    %4 = tensor.empty() : tensor<1x23x40x128xf32>
-    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-    %6 = tensor.empty() : tensor<1x23x40x128xf32>
-    %7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
-    return %7 : tensor<1x23x40x128xf32>
-  }
+  func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
+    // CHECK: %{{[0-9]+}} = "ttnn.reshape"
+    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
+    // CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
+    %0 = tensor.empty() : tensor<1x23x40x128xf32>
+    %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    %2 = tensor.empty() : tensor<1x1x1x128xf32>
+    %3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
+    %4 = tensor.empty() : tensor<1x23x40x128xf32>
+    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    %6 = tensor.empty() : tensor<1x23x40x128xf32>
+    %7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
+    return %7 : tensor<1x23x40x128xf32>
+  }
 }
 
 module {
-  func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
-    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
-    // CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
-    %0 = tensor.empty() : tensor<1x6x2xf32>
-    %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
-    %2 = tensor.empty() : tensor<1x6x1x2xf32>
-    %3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
-    %4 = tensor.empty() : tensor<400x6x1x2xf32>
-    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
-    %6 = tensor.empty() : tensor<2400x1x2xf32>
-    %7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
-    %8 = tensor.empty() : tensor<2400x2xf32>
-    %9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
-    return %9 : tensor<2400x2xf32>
-  }
+  func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
+    // CHECK: %{{[0-9]+}} = "ttnn.repeat"
+    // CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
+    %0 = tensor.empty() : tensor<1x6x2xf32>
+    %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
+    %2 = tensor.empty() : tensor<1x6x1x2xf32>
+    %3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
+    %4 = tensor.empty() : tensor<400x6x1x2xf32>
+    %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
+    %6 = tensor.empty() : tensor<2400x1x2xf32>
+    %7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
+    %8 = tensor.empty() : tensor<2400x2xf32>
+    %9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
+    return %9 : tensor<2400x2xf32>
+  }
 }
+
+module {
+  func.func public @main(%arg0: tensor<512x512xf32>, %arg1: tensor<512x512xf32>) -> (tensor<512x512xf32>) {
+    // CHECK-NOT: "ttnn.repeat"
+    %1 = tensor.empty() : tensor<512x512xf32>
+    %2 = "ttir.broadcast"(%arg0, %1) <{broadcast_dimensions = array<i32 : 1, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
+    %3 = tensor.empty() : tensor<512x512xf32>
+    %4 = "ttir.maximum"(%3, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
+    return %4 : tensor<512x512xf32>
+  }
+}
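The Silicon copy mirrors the Dialect tests so the folder is also exercised through the full hardware pipeline. The RUN lines sit above the hunks shown here; a typical tt-mlir silicon test driver looks roughly like the following (an assumption based on the repository's usual test harness, not visible in this diff):

// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn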
