Conversation

@egebeysel
Contributor

This PR builds upon the previous #146531 and enables scalable vectorization for batch_mmt4d as well.
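For context, a minimal pure-Python reference model of what `linalg.batch_mmt4d` computes (the shapes and helper names here are illustrative, not from the patch): a batched matmul with an already-transposed RHS on a tiled 4-D layout, reducing over the `k` and `k0` dimensions.

```python
# Illustrative reference model (not from the patch) of linalg.batch_mmt4d:
# C[b, m, n, m0, n0] += A[b, m, k, m0, k0] * B[b, n, k, n0, k0]
def batch_mmt4d(A, B, C):
    Bt, M, K = len(A), len(A[0]), len(A[0][0])
    m0, k0 = len(A[0][0][0]), len(A[0][0][0][0])
    N, n0 = len(B[0]), len(B[0][0][0])
    for b in range(Bt):
        for m in range(M):
            for n in range(N):
                for k in range(K):
                    for i in range(m0):
                        for j in range(n0):
                            for l in range(k0):
                                C[b][m][n][i][j] += A[b][m][k][i][l] * B[b][n][k][j][l]
    return C

def ones(*shape):
    # Nested-list tensor of ones with the given shape.
    return [1.0] * shape[0] if len(shape) == 1 else [ones(*shape[1:]) for _ in range(shape[0])]

def zeros(*shape):
    return [0.0] * shape[0] if len(shape) == 1 else [zeros(*shape[1:]) for _ in range(shape[0])]

# With all-ones inputs, every output element is K * k0 (here 5 * 1 = 5).
C = batch_mmt4d(ones(2, 3, 5, 2, 1), ones(2, 4, 5, 3, 1), zeros(2, 3, 4, 2, 3))
assert all(C[b][m][n][i][j] == 5.0
           for b in range(2) for m in range(3) for n in range(4)
           for i in range(2) for j in range(3))
```

The reduction dims (`k`, `k0`) are the ones that show up as `[3, 6]` in the `vector.multi_reduction` produced by the vectorizer below.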

@llvmbot
Member

llvmbot commented Aug 11, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Ege Beysel (egebeysel)

Changes

This PR builds upon the previous #146531 and enables scalable vectorization for batch_mmt4d as well.


Full diff: https://github.com/llvm/llvm-project/pull/152984.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+1)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir (+94)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index cf65e673a5c44..6a6258f0f6236 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2615,6 +2615,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                  isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 isa<linalg::BatchMmt4DOp>(op) ||
                  hasReductionIterator(linalgOp));
 }
 
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 095810fe0451e..1ee1b4da7dfbc 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -933,6 +933,100 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.batch_mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @batch_mmt4d(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x8x1xf32>, %C_in: memref<2x16x16x8x8xf32>) {
+  linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x8x1xf32>)
+               outs(%C_in: memref<2x16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @batch_mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<2x16x16x8x1xf32>, %[[B:.*]]: memref<2x16x16x8x1xf32>, %[[C:.*]]: memref<2x16x16x8x8xf32>) {
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<2x16x16x8x8xf32>, vector<2x16x16x8x8xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x8x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x8x1xf32> to vector<2x16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<2x16x16x8x8xf32>, memref<2x16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %batch_mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_mmt4d_scalable(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+  linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+               outs(%C_in: memref<2x16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @batch_mmt4d_scalable(
+// CHECK-SAME:      %[[A:.*]]: memref<2x16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<2x16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[DIM_2:.*]] = memref.dim %[[B]], %[[C3]] : memref<2x16x16x?x1xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_1:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[DIM_2]], %[[VAL_6]] : vector<2x16x16x[4]x1xi1>
+// CHECK:           %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> } : vector<2x16x16x[4]x1xi1> -> vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]] : vector<2x16x16x8x[4]xi1>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32>
+// CHECK:           %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<2x16x16x16x8x[4]x1xi1>
+// CHECK:           %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32> } : vector<2x16x16x16x8x[4]x1xi1> -> vector<2x16x16x8x[4]xf32>
+// CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> } : vector<2x16x16x8x[4]xi1>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+  linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+               outs(%C_in: memref<2x16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @batch_mmt4d_scalable_with_assume(
+// CHECK-SAME:      %[[A:.*]]: memref<2x16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<2x16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+// CHECK-NOT:       mask
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
+// CHECK:           vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+    transform.yield
+  }
+}
+
+
 // -----
 
 ///----------------------------------------------------------------------------------------

@banach-space banach-space left a comment

LGTM, thanks!

We could be tempted to check that we don't make the batch dim scalable, but that sort of special-casing would probably be too much. It also wouldn't make much sense at the Linalg level.

// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[DIM_2]], %[[VAL_6]] : vector<2x16x16x[4]x1xi1>
// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> } : vector<2x16x16x[4]x1xi1> -> vector<2x16x16x16x8x[4]x1xf32>
// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]] : vector<2x16x16x8x[4]xi1>
// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32>
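As an aside, the semantics of the `vector.create_mask` ops quoted above can be modelled with a simple per-lane predicate (this Python sketch is illustrative, not from the patch): a lane is enabled iff every one of its indices is below the corresponding mask operand, which is how the dynamic `%DIM_2` bound masks off out-of-bounds lanes of the scalable `[4]` dimension.

```python
# Illustrative model of vector.create_mask for a 2-D mask:
# lane (i, j) is true iff i < bounds[0] and j < bounds[1].
def create_mask(shape, bounds):
    return [[i < bounds[0] and j < bounds[1] for j in range(shape[1])]
            for i in range(shape[0])]

# E.g. an 8x4 mask where a dynamic trailing bound resolved to 3 lanes
# out of 4: only the first 3 columns are enabled.
m = create_mask((8, 4), (8, 3))
assert all(m[i][j] == (j < 3) for i in range(8) for j in range(4))
```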
Contributor

[nit] I know that you were inspired by my change for linalg.mmt4d, but would you mind replacing vars like VAL_N with something descriptive? This is a nice-to-have 😅

Contributor Author

Done - also made the mmt4d ones more descriptive while I was at it :)

@rengolin
Member

In the future, it would be good to know the cost of representing mmt4d and batch_mmt4d as variations on contract, so that we can make them specializations like we did with the transpose variants.

@egebeysel
Contributor Author

In the future, it would be good to know the cost of representing mmt4d and batch_mmt4d as variations on contract, so that we can make them specializations like we did with the transpose variants.

Absolutely! I was just exploring how the mmt4d -> contraction -> outer_product lowering might look - I'll play with this later if that's okay.
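A quick pure-Python sanity check (illustrative, not from this thread) that mmt4d really is expressible as a generic contraction: unpacking the tiled operands recovers a plain matmul with transposed B, i.e. C = A · Bᵀ on the unpacked 2-D matrices.

```python
import random

random.seed(0)

def tensor(*shape):
    # Nested-list tensor filled with random values.
    if len(shape) == 1:
        return [random.random() for _ in range(shape[0])]
    return [tensor(*shape[1:]) for _ in range(shape[0])]

M, N, K, m0, n0, k0 = 2, 3, 4, 2, 2, 1
A = tensor(M, K, m0, k0)  # mmt4d LHS layout
B = tensor(N, K, n0, k0)  # mmt4d RHS layout (already "transposed")

# mmt4d: C[m, n, m0, n0] = sum over k, k0 of A[m, k, m0, k0] * B[n, k, n0, k0]
C4 = [[[[sum(A[m][k][i][l] * B[n][k][j][l]
             for k in range(K) for l in range(k0))
         for j in range(n0)] for i in range(m0)]
       for n in range(N)] for m in range(M)]

# Unpack the tiled layouts into plain 2-D matrices ...
Au = [[A[r // m0][c // k0][r % m0][c % k0] for c in range(K * k0)]
      for r in range(M * m0)]
Bu = [[B[r // n0][c // k0][r % n0][c % k0] for c in range(K * k0)]
      for r in range(N * n0)]
# ... and the tiled result matches the ordinary matmul with transposed B.
Cu = [[sum(Au[r][c] * Bu[s][c] for c in range(K * k0))
       for s in range(N * n0)] for r in range(M * m0)]
assert all(abs(Cu[m * m0 + i][n * n0 + j] - C4[m][n][i][j]) < 1e-9
           for m in range(M) for n in range(N)
           for i in range(m0) for j in range(n0))
```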

@egebeysel egebeysel merged commit 8de85e7 into llvm:main Aug 14, 2025
9 checks passed
@egebeysel egebeysel deleted the egb.scalable_vectorization_batch_mmt4d branch August 14, 2025 09:48
@banach-space
Contributor

In the future, it would be good to know the cost of representing mmt4d and batch_mmt4d as variations on contract, so that we can make them specializations like we did with the transpose variants.

Why specialization of linalg.contract?

For the transpose variants that was straightforward - we were just dealing with different variants of linalg.matmul and were creating specialisations of that. But with MMT4D, it's less obvious to me. I guess you'd like to see us converging towards the Linalg tree that you proposed? So the specialisation tree would look like this:

  • linalg.contract -> linalg.matmul -> linalg.matmul_transpose_a (2 specializations)
  • linalg.contract -> linalg.matmul -> linalg.matmul_transpose_b (2 specializations)
  • linalg.contract -> linalg.mmt4d (1 specialization)
  • linalg.contract -> linalg.batch_mmt4d (1 specialization)

Just making sure that I understand what you meant 😅

@egebeysel , IIUC, you were thinking of vector.contract rather than the Linalg Op Tree? :)

@rengolin
Member

That's what I meant. I've misused the term specialization. It's still not clear to me how we do aliasing (wrt pattern matching, class structure, rewrites).

@egebeysel
Contributor Author

@egebeysel , IIUC, you were thinking of vector.contract rather than the Linalg Op Tree? :)

Yes, I was thinking of the linalg.mmt4d -> vector.contract lowerings, which I now see was a misunderstanding 😄
