Skip to content

Commit efbb570

Browse files
committed
Fixups
1 parent ee29dbe commit efbb570

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,8 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
434434
return rewriter.notifyMatchFailure(writeOp,
435435
kMatchFailureNotSMETileTypeMultiple);
436436

437+
// Note: We also disallow masks where any dimension is larger than 16 as
438+
// that cannot be lowered to arm_sve.psel.
437439
auto mask = writeOp.getMask();
438440
if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
439441
vectorType.getDimSize(1) > 16)))
@@ -462,9 +464,9 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
462464
rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
463465
rewriter.setInsertionPointToStart(storeLoop.getBody());
464466

465-
// For each tile sub-tile of the multi-tile `vectorType`.
467+
// For each sub-tile of the multi-tile `vectorType`.
466468
auto inputSMETiles = adaptor.getVector();
467-
auto inductionVar = storeLoop.getInductionVar();
469+
auto tileSliceIndex = storeLoop.getInductionVar();
468470
for (auto [index, smeTile] : llvm::enumerate(
469471
decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
470472
// The coordinates of the tile within `vectorType`.
@@ -473,7 +475,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
473475

474476
// The current slice of `vectorType` we are processing.
475477
auto sliceIndex =
476-
rewriter.create<arith::AddIOp>(loc, tileRow, inductionVar);
478+
rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
477479

478480
// Where in the destination memref the current slice will be stored.
479481
auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
@@ -491,9 +493,10 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
491493
loc, sliceMaskType, sliceMask, smeTile.col);
492494
}
493495

494-
// Extract and store the current slice slice.
496+
// Extract and store the current slice.
495497
Value tile = inputSMETiles[index];
496-
auto slice = rewriter.create<vector::ExtractOp>(loc, tile, inductionVar);
498+
auto slice =
499+
rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
497500
rewriter.create<vector::TransferWriteOp>(
498501
loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
499502
AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector
182182
// CHECK-NEXT: %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
183183
// CHECK-NEXT: vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
184184
// CHECK-NEXT: %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
185-
// CHECK-NEXT: %[[BOTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
186-
// CHECK-NEXT: vector.transfer_write %[[BOTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
185+
// CHECK-NEXT: %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
186+
// CHECK-NEXT: vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
187187
// CHECK-NEXT: }
188188
// CHECK-NEXT: return
189189
%c0 = arith.constant 0 : index
@@ -277,6 +277,20 @@ func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32
277277

278278
// -----
279279

280+
// Masked writes where any dimension of the mask is > 16 are not supported for the store loop lowering.
281+
282+
// CHECK-LABEL: @negative_transfer_write_f32_scalable_32x32
283+
// CHECK-NOT: scf.for
284+
func.func @negative_transfer_write_f32_scalable_32x32(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[32]x[32]xf32>)
285+
{
286+
%c0 = arith.constant 0 : index
287+
%mask = vector.create_mask %dim0, %dim1 : vector<[32]x[32]xi1>
288+
vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[32]x[32]xf32>, memref<?x?xf32>
289+
return
290+
}
291+
292+
// -----
293+
280294
#transpose = affine_map<(d0, d1) -> (d1, d0)>
281295

282296
// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(

0 commit comments

Comments
 (0)