Pattern to fuse/fold TFL_TransposeOp into TFL_BatchMatMulOp
PiperOrigin-RevId: 546444604
tensorflower-gardener committed Jul 8, 2023
1 parent 3c9de93 commit 14df49d
Showing 3 changed files with 57 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -646,6 +646,26 @@ func.func @FuseReshapeAroundBMMRHS(%arg0: tensor<1x3x6x5x1024xf32>) -> tensor<1x
// CHECK: return %0 : tensor<1x3x6x5x8192xf32>
}

// CHECK-LABEL: @FuseTransposeIntoBMM_RHS
func.func @FuseTransposeIntoBMM_RHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x1440x1440xf32> {
%cst_1 = arith.constant dense_resource<__elided__> : tensor<3xi32>
%32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<1x1440x256xf32>, tensor<3xi32>) -> tensor<1x256x1440xf32>
%33 = "tfl.batch_matmul"(%arg0, %32) {adj_x = false, adj_y = false} : (tensor<1x4x1440x256xf32>, tensor<1x256x1440xf32>) -> tensor<1x4x1440x1440xf32>
return %33 : tensor<1x4x1440x1440xf32>
// CHECK: %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x1440x256xf32>, tensor<1x1440x256xf32>) -> tensor<1x4x1440x1440xf32>
// CHECK: return %0 : tensor<1x4x1440x1440xf32>
}

// CHECK-LABEL: @FuseTransposeIntoBMM_LHS
func.func @FuseTransposeIntoBMM_LHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x256x256xf32> {
%cst_1 = arith.constant dense_resource<__elided__> : tensor<3xi32>
%32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<1x1440x256xf32>, tensor<3xi32>) -> tensor<1x256x1440xf32>
%33 = "tfl.batch_matmul"(%32, %arg0) {adj_x = false, adj_y = false} : (tensor<1x256x1440xf32>, tensor<1x4x1440x256xf32>) -> tensor<1x4x256x256xf32>
return %33 : tensor<1x4x256x256xf32>
// CHECK: %0 = "tfl.batch_matmul"(%arg1, %arg0) {adj_x = true, adj_y = false} : (tensor<1x1440x256xf32>, tensor<1x4x1440x256xf32>) -> tensor<1x4x256x256xf32>
// CHECK: return %0 : tensor<1x4x256x256xf32>
}

// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst
// FOLD-LABEL: @FuseFullyConnectedReshapeAddConst
func.func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
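The CHECK lines above capture the intended rewrite: the transpose op disappears and batch_matmul reads the original operand with the matching adj flag flipped to true. The identity being exploited is A * transpose(B) == batch_matmul(A, B, adj_y = true). A minimal numeric sketch of that identity in plain C++ (illustrative only; no TFLite APIs are involved):

#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// C = A * transpose(B), computed by materializing the transpose first
// (the pre-rewrite form: tfl.transpose feeding tfl.batch_matmul).
Matrix MatMulExplicitTranspose(const Matrix& a, const Matrix& b) {
  const size_t m = a.size(), k = a[0].size(), n = b.size();  // b is n x k
  Matrix bt(k, std::vector<double>(n));
  for (size_t i = 0; i < n; ++i)
    for (size_t j = 0; j < k; ++j) bt[j][i] = b[i][j];
  Matrix c(m, std::vector<double>(n, 0.0));
  for (size_t i = 0; i < m; ++i)
    for (size_t j = 0; j < n; ++j)
      for (size_t p = 0; p < k; ++p) c[i][j] += a[i][p] * bt[p][j];
  return c;
}

// C = A * transpose(B), computed directly from B's rows
// (the post-rewrite form: batch_matmul with adj_y = true).
Matrix MatMulAdjY(const Matrix& a, const Matrix& b) {
  const size_t m = a.size(), k = a[0].size(), n = b.size();
  Matrix c(m, std::vector<double>(n, 0.0));
  for (size_t i = 0; i < m; ++i)
    for (size_t j = 0; j < n; ++j)
      for (size_t p = 0; p < k; ++p) c[i][j] += a[i][p] * b[j][p];
  return c;
}

int main() {
  const Matrix a = {{1, 2, 3}, {4, 5, 6}};     // 2 x 3
  const Matrix b = {{7, 8, 9}, {10, 11, 12}};  // 2 x 3; transpose(b) is 3 x 2
  assert(MatMulExplicitTranspose(a, b) == MatMulAdjY(a, b));
  return 0;
}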
11 changes: 11 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -139,6 +139,17 @@ bool BroadcastDimsProductEqual(Value input, Value output,
return (agg_value == output_shape[agg_start_idx]);
}

// Returns true if the last two dimensions of `input` are transposed in
// `output`, i.e. an input shape [..., a, b] maps to an output shape
// [..., b, a].
bool AreLastTwoDimsTransposed(Value input, Value output) {
ArrayRef<int64_t> input_shape = input.getType().cast<ShapedType>().getShape();
ArrayRef<int64_t> output_shape =
output.getType().cast<ShapedType>().getShape();

return (input_shape.back() == output_shape[output_shape.size() - 2]) &&
(input_shape[input_shape.size() - 2] == output_shape.back());
}
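To make the predicate's behavior concrete, here is a standalone mirror of the same check over raw shape vectors (a sketch; the function name and the explicit rank guard are additions for illustration, not part of the patch, which leaves rank handling to the pattern's other constraints):

#include <cassert>
#include <cstdint>
#include <vector>

// True iff the trailing two dimension sizes are swapped between input
// and output, mirroring TFL::AreLastTwoDimsTransposed above.
bool LastTwoDimsSwapped(const std::vector<int64_t>& in,
                        const std::vector<int64_t>& out) {
  if (in.size() < 2 || out.size() < 2) return false;  // extra guard, see note
  return in[in.size() - 1] == out[out.size() - 2] &&
         in[in.size() - 2] == out[out.size() - 1];
}

int main() {
  // The shapes from the RHS test: tensor<1x1440x256> -> tensor<1x256x1440>.
  assert(LastTwoDimsSwapped({1, 1440, 256}, {1, 256, 1440}));
  // Moving a batch dimension instead does not qualify.
  assert(!LastTwoDimsSwapped({1440, 1, 256}, {256, 1, 1440}));
  return 0;
}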

// Returns whether the given type `a` is broadcast-compatible with `b`.
bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
return OpTrait::util::getBroadcastedType(a, b) != Type();
26 changes: 26 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -57,6 +57,8 @@ class HasRank<int n> : Constraint<
class FloatValueEquals<string val> : Constraint<CPred<
"FloatValueEquals($0, " # val # ")">>;

class IsBoolAttrEqual<string true_or_false> : Constraint<CPred<
"$0.getValue() == "#true_or_false#"">>;

// Flattens a constant tensor to 1D.
def FlattenTo1D : NativeCodeCall<"FlattenTo1D($0)">;
@@ -1480,3 +1482,27 @@ def FuseReshapesAroundBatchMatMulLHS1: Pat<
(BroadcastDimsProductEqual<1> $input, $initial_shape_change),
(BroadcastDimsProductEqual<1> $final_shape_change, $bmm_tmp_output),
(AreTensorSubSectionShapesEqual<1, 1> $input, $final_shape_change)]>;

def AreLastTwoDimsTransposed : Constraint<CPred<
"TFL::AreLastTwoDimsTransposed($0, $1)">>;

// Fuse redundant TFL_TransposeOp into TFL_BatchMatMulOp
def FuseTransposeIntoBatchMatMulRHS: Pat<
(TFL_BatchMatMulOp $lhs,
(TFL_TransposeOp:$transposed_value $input, (Arith_ConstantOp $p0)),
$adj_x, $adj_y, $asymmetric_quantize_inputs),
(TFL_BatchMatMulOp $lhs, $input, $adj_x, ConstBoolAttrTrue, $asymmetric_quantize_inputs),
[(AreLastTwoDimsTransposed $input, $transposed_value),
(IsBoolAttrEqual<"false"> $adj_y),
(AreTensorSubSectionShapesEqual<0, 2> $input, $transposed_value)]>;

// Fuse redundant TFL_TransposeOp into TFL_BatchMatMulOp
def FuseTransposeIntoBatchMatMulLHS: Pat<
(TFL_BatchMatMulOp
(TFL_TransposeOp:$transposed_value $input, (Arith_ConstantOp $p0)),
$rhs, $adj_x, $adj_y, $asymmetric_quantize_inputs),
(TFL_BatchMatMulOp $input, $rhs, ConstBoolAttrTrue, $adj_y, $asymmetric_quantize_inputs),
[(AreLastTwoDimsTransposed $input, $transposed_value),
(IsBoolAttrEqual<"false"> $adj_x),
(AreTensorSubSectionShapesEqual<0, 2> $input, $transposed_value)]>;
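Note that both patterns guard on shapes rather than on the transpose's permutation constant ($p0 is matched but never inspected): AreLastTwoDimsTransposed requires the trailing two dimension sizes to swap, and AreTensorSubSectionShapesEqual<0, 2> requires every leading (batch) dimension to be unchanged. When the dimension sizes are pairwise distinct, that combination is equivalent to requiring the permutation [0, ..., rank-3, rank-1, rank-2]. A permutation-level sketch of that condition (illustrative only; nothing like it appears in the patch):

#include <cstdint>
#include <vector>

// True iff `perm` fixes every leading dimension and swaps the last two,
// i.e. perm == [0, 1, ..., n-3, n-1, n-2].
bool IsTrailingSwap(const std::vector<int64_t>& perm) {
  const int64_t n = static_cast<int64_t>(perm.size());
  if (n < 2) return false;
  for (int64_t i = 0; i < n - 2; ++i)
    if (perm[i] != i) return false;  // batch dims must stay in place
  return perm[n - 2] == n - 1 && perm[n - 1] == n - 2;  // last two swapped
}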
