-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][linalg] Add support for scalable vectorization of linalg.batch_mmt4d
#152984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Add support for scalable vectorization of linalg.batch_mmt4d
#152984
Conversation
…_mmt4d Signed-off-by: Ege Beysel <[email protected]>
|
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Ege Beysel (egebeysel) Changes: This PR builds upon the previous #146531 and enables scalable vectorization for `batch_mmt4d` as well. Full diff: https://github.com/llvm/llvm-project/pull/152984.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index cf65e673a5c44..6a6258f0f6236 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2615,6 +2615,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+ isa<linalg::BatchMmt4DOp>(op) ||
hasReductionIterator(linalgOp));
}
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 095810fe0451e..1ee1b4da7dfbc 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -933,6 +933,100 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.batch_mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @batch_mmt4d(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x8x1xf32>, %C_in: memref<2x16x16x8x8xf32>) {
+ linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x8x1xf32>)
+ outs(%C_in: memref<2x16x16x8x8xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @batch_mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>, %[[B:.*]]: memref<2x16x16x8x1xf32>, %[[C:.*]]: memref<2x16x16x8x8xf32>) {
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<2x16x16x8x8xf32>, vector<2x16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x8x1xf32> to vector<2x16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<2x16x16x8x8xf32>, memref<2x16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %batch_mmt4d : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @batch_mmt4d_scalable(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+ linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+ outs(%C_in: memref<2x16x16x8x?xf32>)
+ return
+}
+// CHECK-LABEL: func.func @batch_mmt4d_scalable(
+// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C3]] : memref<2x16x16x?x1xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[DIM_2]], %[[VAL_6]] : vector<2x16x16x[4]x1xi1>
+// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> } : vector<2x16x16x[4]x1xi1> -> vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]] : vector<2x16x16x8x[4]xi1>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32>
+// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<2x16x16x16x8x[4]x1xi1>
+// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32> } : vector<2x16x16x16x8x[4]x1xi1> -> vector<2x16x16x8x[4]xf32>
+// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> } : vector<2x16x16x8x[4]xi1>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+ linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+ outs(%C_in: memref<2x16x16x8x?xf32>)
+ return
+}
+// CHECK-LABEL: func.func @batch_mmt4d_scalable_with_assume(
+// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+// CHECK-NOT: mask
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
+// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
+// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+ transform.yield
+ }
+}
+
+
// -----
///----------------------------------------------------------------------------------------
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
We could be tempted to check that we don't make the batch dim scalable, but that sort of special-casing would probably be too much. Also, wouldn't make much sense at the Linalg level
| // CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[DIM_2]], %[[VAL_6]] : vector<2x16x16x[4]x1xi1> | ||
| // CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> } : vector<2x16x16x[4]x1xi1> -> vector<2x16x16x16x8x[4]x1xf32> | ||
| // CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]] : vector<2x16x16x8x[4]xi1> | ||
| // CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] I know that you were inspired by my change for linalg.mmt4d, but would you mind replacing vars like VAL_N with something descriptive? This is a nice-to-have 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done - also made the mmt4d ones more descriptive while I was at it :)
|
In the future, it would be good to know what's the cost of representing |
Absolutely! I was just exploring how the |
Signed-off-by: Ege Beysel <[email protected]>
Why specialization of For the transpose variants that was straightforward - we were just dealing with different variants of ? So the specialisation tree would be like this:
Just making sure that I understand what you meant 😅 @egebeysel , IIUC, you were thinking of |
|
That's what I meant. I've misused the term specialization. It's still not clear to me how we do aliasing (wrt pattern matching, class structure, rewrites). |
Yes, I was thinking of the |
This PR builds upon the previous #146531 and enables scalable vectorization for
`batch_mmt4d` as well.