@@ -434,6 +434,8 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
434434 return rewriter.notifyMatchFailure (writeOp,
435435 kMatchFailureNotSMETileTypeMultiple );
436436
437+ // Note: We also disallow masks where any dimension is larger than 16 as
438+ // that won't be possible to arm_sve.psel.
437439 auto mask = writeOp.getMask ();
438440 if (!isSupportedMaskOp (mask) || (mask && (vectorType.getDimSize (0 ) > 16 ||
439441 vectorType.getDimSize (1 ) > 16 )))
@@ -462,9 +464,9 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
462464 rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step);
463465 rewriter.setInsertionPointToStart (storeLoop.getBody ());
464466
465- // For each tile sub-tile of the multi-tile `vectorType`.
467+ // For each sub-tile of the multi-tile `vectorType`.
466468 auto inputSMETiles = adaptor.getVector ();
467- auto inductionVar = storeLoop.getInductionVar ();
469+ auto tileSliceIndex = storeLoop.getInductionVar ();
468470 for (auto [index, smeTile] : llvm::enumerate (
469471 decomposeToSMETiles (rewriter, vectorType, smeTileType))) {
470472 // The coordinates of the tile within `vectorType`.
@@ -473,7 +475,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
473475
474476 // The current slice of `vectorType` we are processing.
475477 auto sliceIndex =
476- rewriter.create <arith::AddIOp>(loc, tileRow, inductionVar );
478+ rewriter.create <arith::AddIOp>(loc, tileRow, tileSliceIndex );
477479
478480 // Where in the destination memref the current slice will be stored.
479481 auto storeRow = rewriter.create <arith::AddIOp>(loc, sliceIndex,
@@ -491,9 +493,10 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
491493 loc, sliceMaskType, sliceMask, smeTile.col );
492494 }
493495
494- // Extract and store the current slice slice .
496+ // Extract and store the current slice.
495497 Value tile = inputSMETiles[index];
496- auto slice = rewriter.create <vector::ExtractOp>(loc, tile, inductionVar);
498+ auto slice =
499+ rewriter.create <vector::ExtractOp>(loc, tile, tileSliceIndex);
497500 rewriter.create <vector::TransferWriteOp>(
498501 loc, slice, writeOp.getSource (), ValueRange{storeRow, storeCol},
499502 AffineMapAttr::get (writeOp.getPermutationMap ().dropResult (0 )),
0 commit comments