diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index b595c6dd8a684..96dad6518fec8 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -373,6 +374,139 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+/// Legalize a multi-tile transfer_write as a single store loop. This is done
+/// as part of type decomposition, as at this level we know each tile write is
+/// disjoint, but that information is lost after decomposition (without
+/// analysis to reconstruct it).
+///
+/// Example (pseudo-MLIR):
+///
+/// ```
+/// vector.transfer_write %vector, %dest[%y, %x], %mask
+///   : vector<[16]x[8]xi16>, memref<?x?xi16>
+/// ```
+/// Is rewritten to:
+/// ```
+/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
+///   %upper_slice_mask = vector.extract %mask[%slice_idx]      ─┐
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
+///   %upper_slice = vector.extract %upper_tile[%slice_idx]       |- Store upper tile
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>                |
+///   vector.transfer_write %upper_slice,                         |
+///     %dest[%slice_idx + %y, %x], %upper_slice_mask             |
+///     : vector<[8]xi16>, memref<?x?xi16>                       ─┘
+///   %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
+///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
+///   %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
+///   vector.transfer_write %lower_slice,                         |
+///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
+///     : vector<[8]xi16>, memref<?x?xi16>                       ─┘
+/// }
+/// ```
+struct LegalizeMultiTileTransferWriteAsStoreLoop
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (writeOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(
+          writeOp, "TODO: tensor semantics are unsupported");
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNonPermutationMap);
+
+    bool transposed = !permutationMap.isIdentity();
+    if (transposed)
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "TODO: transpose unsupported");
+
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    // Note: We also disallow masks where any dimension is > 16 because that
+    // prevents the masking from being lowered to use arm_sve.psel.
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
+                                              vectorType.getDimSize(1) > 16)))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureUnsupportedMaskOp);
+
+    auto loc = writeOp.getLoc();
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto createVscaleMultiple = [&](int64_t multiplier) {
+      return rewriter.create<arith::MulIOp>(
+          loc, vscale,
+          rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+    };
+
+    // Get SME tile and slice types.
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto minTileSlices = smeTileType.getDimSize(0);
+    VectorType sliceMaskType =
+        VectorType::get(minTileSlices, rewriter.getI1Type(), true);
+
+    // Create loop over all tile slices.
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = createVscaleMultiple(minTileSlices);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto storeLoop =
+        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    rewriter.setInsertionPointToStart(storeLoop.getBody());
+
+    // For each sub-tile of the multi-tile `vectorType`.
+    auto inputSMETiles = adaptor.getVector();
+    auto tileSliceIndex = storeLoop.getInductionVar();
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
+      // The coordinates of the tile within `vectorType`.
+      auto tileRow = createVscaleMultiple(smeTile.row);
+      auto tileCol = createVscaleMultiple(smeTile.col);
+
+      // The current slice of `vectorType` we are processing.
+      auto sliceIndex =
+          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
+
+      // Where in the destination memref the current slice will be stored.
+      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
+                                                     writeOp.getIndices()[0]);
+      auto storeCol =
+          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
+
+      // Extract the mask for the current slice.
+      Value sliceMask = nullptr;
+      if (mask) {
+        sliceMask = rewriter.create<vector::ExtractOp>(
+            loc, mask, OpFoldResult(sliceIndex));
+        if (sliceMaskType != sliceMask.getType())
+          sliceMask = rewriter.create<vector::ScalableExtractOp>(
+              loc, sliceMaskType, sliceMask, smeTile.col);
+      }
+
+      // Extract and store the current slice.
+      Value tile = inputSMETiles[index];
+      auto slice =
+          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
+      rewriter.create<vector::TransferWriteOp>(
+          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
+          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
+          sliceMask,
+          rewriter.getBoolArrayAttr(
+              ArrayRef(writeOp.getInBoundsValues()).drop_front()));
+    }
+
+    rewriter.eraseOp(writeOp);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ArmSME-specific fixup canonicalizations/folds
 //===----------------------------------------------------------------------===//
@@ -663,9 +797,12 @@ struct VectorLegalizationPass
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
                  ConvertIllegalShapeCastOpsToTransposes>(context);
-    // Note: High benefit to ensure masked outer products are lowered first.
-    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
-        converter, context, 1024);
+    // Note: These two patterns are added with a high benefit to ensure:
+    //  - Masked outer products are handled before unmasked ones
+    //  - Multi-tile writes are lowered as a store loop (if possible)
+    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
+                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
+                                                            /*benefit=*/1024);
     patterns.add<LegalizeArithConstantOpsByDecomposition,
                  LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -166,11 +166,17 @@ func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0:
 func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
   // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-  // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
-  // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT:   %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
@@ -201,6 +207,90 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 
 // -----
 
+// CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
+// CHECK-SAME:                                          %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:                                          %[[DIM_0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                          %[[DIM_1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                          %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                          %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                          %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                          %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
+func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+// Tensor semantics are not supported for the store loop lowering.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %vec: vector<[8]x[8]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
+  return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// Transposes are not supported for the store loop lowering.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
+  return
+}
+
+// -----
+
+// Masked writes where any dimension of the mask is > 16 are not supported for the store loop lowering.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_32x32
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_32x32(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[32]x[32]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[32]x[32]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[32]x[32]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
 #transpose = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
@@ -209,6 +299,7 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
@@ -221,10 +312,19 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: me
   // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
index ada744b322fe9..03a7d25cffa76 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -1,7 +1,8 @@
 // RUN: mlir-opt %s \
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN:   -test-lower-to-arm-sme -test-lower-to-llvm | \
+// RUN:   -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \