Skip to content

Commit

Permalink
[mlir][vector][test] Split tests from vector-transfer-flatten.mlir (#…
Browse files Browse the repository at this point in the history
…102584)

Move tests that exercise DropUnitDimFromElementwiseOps and
DropUnitDimsFromTransposeOp to a dedicated file.

While these patterns are collected under populateFlattenVectorTransferPatterns
(and are tested via -test-vector-transfer-flatten-patterns), they can actually
be tested without the xfer Ops, and hence the split.

Note, this is mostly just moving tests from one file to another. The only real
change is the removal of the following check-lines:

```mlir
//   CHECK-128B-NOT:   memref.collapse_shape
```

These were added specifically to check the "flattening" logic (which introduces
`memref.collapse_shape`). However, these tests were never meant to test that
logic (in fact, that's the reason I am moving them to a different file) and
hence are being removed as copy&paste errors.

I also removed the following TODO:

```mlir
/// TODO: Potential duplication with tests from:
///   * "vector-dropleadunitdim-transforms.mlir"
///   * "vector-transfer-drop-unit-dims-patterns.mlir"
```
I've checked what patterns are triggered in those test files and neither
DropUnitDimFromElementwiseOps nor DropUnitDimsFromTransposeOp does.
  • Loading branch information
banach-space authored Aug 9, 2024
1 parent 6f19a7b commit 5123f2c
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 236 deletions.
209 changes: 209 additions & 0 deletions mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s

///----------------------------------------------------------------------------------------
/// [Pattern: DropUnitDimFromElementwiseOps]
///----------------------------------------------------------------------------------------

func.func @fold_unit_dim_add_basic(%vec : vector<1x8xi32>) -> vector<1x8xi32> {
%res = arith.addi %vec, %vec : vector<1x8xi32>
return %res : vector<1x8xi32>
}
// CHECK-LABEL: func.func @fold_unit_dim_add_basic(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8xi32> to vector<8xi32>
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8xi32> to vector<8xi32>
// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32>
// CHECK: return %[[VAL_4]] : vector<1x8xi32>

// -----

func.func @fold_unit_dim_add_leading_and_trailing(%vec : vector<1x8x1xi32>) -> vector<1x8x1xi32> {
%res = arith.addi %vec, %vec : vector<1x8x1xi32>
return %res : vector<1x8x1xi32>
}
// CHECK-LABEL: func.func @fold_unit_dim_add_leading_and_trailing(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8x1xi32>) -> vector<1x8x1xi32> {
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8x1xi32> to vector<8xi32>
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8x1xi32> to vector<8xi32>
// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8x1xi32>
// CHECK: return %[[VAL_4]] : vector<1x8x1xi32>

// -----

func.func @fold_unit_dim_add(%vec_0 : vector<8x1xi32>,
%vec_1 : vector<1x8xi32>) -> vector<8xi32> {
%sc_vec_0 = vector.shape_cast %vec_0 : vector<8x1xi32> to vector<1x8xi32>
%add = arith.addi %sc_vec_0, %vec_1 : vector<1x8xi32>
%res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
return %res : vector<8xi32>
}

// CHECK-LABEL: func.func @fold_unit_dim_add(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1xi32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> {
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32>
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>

// -----

func.func @fold_unit_dim_mulf(%vec_0 : vector<8x[2]x1xf32>,
%vec_1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
%sc_vec_0 = vector.shape_cast %vec_0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
%add = arith.mulf %sc_vec_0, %vec_1 : vector<1x8x[2]xf32>
%res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
return %res : vector<8x[2]xf32>
}

// CHECK-LABEL: func.func @fold_unit_dim_mulf(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32>
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32>
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
// CHECK: return %[[VAL_4]] : vector<8x[2]xf32>

// -----

func.func @fold_unit_dim_sitofp(%vec : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
%sc_vec_0 = vector.shape_cast %vec : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
%add = arith.sitofp %sc_vec_0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
%res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
return %res : vector<8x[2]xf32>
}

// CHECK-LABEL: func.func @fold_unit_dim_sitofp(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8>
// CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
// CHECK: return %[[VAL_2]] : vector<8x[2]xf32>

// -----

// All shape casts are folded away

func.func @fold_unit_dims_entirely(%vec_0 : vector<8xi32>,
%vec_1 : vector<8xi32>,
%vec_2 : vector<8xi32>) -> vector<8xi32> {
%sc_vec_0 = vector.shape_cast %vec_0 : vector<8xi32> to vector<1x8xi32>
%sc_vec_1 = vector.shape_cast %vec_1 : vector<8xi32> to vector<1x8xi32>
%sc_vec_2 = vector.shape_cast %vec_2 : vector<8xi32> to vector<1x8xi32>
%mul = arith.muli %sc_vec_0, %sc_vec_1 : vector<1x8xi32>
%add = arith.addi %mul, %sc_vec_2 : vector<1x8xi32>
%res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
return %res : vector<8xi32>
}

// CHECK-LABEL: func.func @fold_unit_dims_entirely(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi32>, %[[VAL_1:.*]]: vector<8xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>

// -----

func.func @fold_inner_unit_dim(%vec_0 : vector<8x1x3xf128>,
%vec_1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
%sc_vec_1 = vector.shape_cast %vec_1 : vector<1x8x3xf128> to vector<8x1x3xf128>
%mul = arith.mulf %vec_0, %sc_vec_1 : vector<8x1x3xf128>
%res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
return %res : vector<8x3xf128>
}

// CHECK-LABEL: func.func @fold_inner_unit_dim(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
// CHECK: return %[[VAL_4]] : vector<8x3xf128>

// -----

func.func @fold_inner_unit_dim_scalable(%vec_0 : vector<8x1x[1]x3xf128>,
%vec_1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
%sc_vec_1 = vector.shape_cast %vec_1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
%mul = arith.mulf %vec_0, %sc_vec_1 : vector<8x1x[1]x3xf128>
%res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
return %res : vector<8x[1]x3xf128>
}

// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128>

// -----

func.func @fold_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1xf32> {
%0 = arith.mulf %vec, %vec : vector<1x1xf32>
%res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32>
return %res : vector<1xf32>
}

// CHECK-LABEL: func.func @fold_all_unit_dims(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32>
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32>
// CHECK: return %[[VAL_3]] : vector<1xf32>

///----------------------------------------------------------------------------------------
/// [Pattern: DropUnitDimsFromTransposeOp]
///----------------------------------------------------------------------------------------

func.func @transpose_with_internal_unit_dims(%vec: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
%res = vector.transpose %vec, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
return %res : vector<[4]x1x1x4xf32>
}

// CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32>

// -----

func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> vector<1x1x4x2x[1]xf32>
{
%res = vector.transpose %vec, [4, 1, 3, 2, 0] : vector<[1]x1x2x4x1xf32> to vector<1x1x4x2x[1]xf32>
return %res: vector<1x1x4x2x[1]xf32>
}

// CHECK-LABEL: func.func @transpose_with_scalable_unit_dims(
// CHECK-SAME: %[[VEC:.*]]: vector<[1]x1x2x4x1xf32>)
// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<[1]x1x2x4x1xf32> to vector<[1]x2x4xf32>
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [2, 1, 0] : vector<[1]x2x4xf32> to vector<4x2x[1]xf32>
// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<4x2x[1]xf32> to vector<1x1x4x2x[1]xf32>
// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<1x1x4x2x[1]xf32>

// -----

func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
%res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
return %res : vector<1x1x1xf32>
}
// The `vec` is returned because there are other flattening patterns that fold
// vector.shape_cast ops away.
// CHECK-LABEL: func.func @transpose_with_all_unit_dims
// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]]
// CHECK-NEXT: return %[[VEC]]

// -----

func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
%res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
return %res : vector<4x3x2xf32>
}

// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
// CHECK-NOT: vector.shape_cast
Loading

0 comments on commit 5123f2c

Please sign in to comment.