@@ -376,27 +376,36 @@ struct LegalizeTransferWriteOpsByDecomposition
376376
377377// / Legalize a multi-tile transfer_write as a single store loop. This is done as
378378// / part of type decomposition as at this level we know each tile write is
379- // / disjoint, but that information is lost after decomposition (without
380- // / static analysis).
379+ // / disjoint, but that information is lost after decomposition (without analysis
380+ // / to reconstruct it).
381381// /
382- // / Example (in pseudo-MLIR):
382+ // / Example:
383383// /
384384// / ```
385- // / vector.transfer_write vector, dest[x, y], mask
386- // / : vector<[16]x[4]xf32>, memref<?x?xf32>
385+ // / vector.transfer_write %vector, %dest[%y, %x], %mask
386+ // / : vector<[16]x[8]xi16>, memref<?x?xi16>
387387// / ```
388388// / Is rewritten to:
389389// / ```
390- // / for i in range (0, 4 * vscale) {
391- // / let sliceRow = i + tile_n.row * vscale; ─┐
392- // / let sliceCol = tile_n.col * vscale; |
393- // / slice = vector.extract tile_n[i] |
394- // / : vector<[4]xf32> from vector<[16]x[4]xf32> |
395- // / slice_mask = vector.extract mask[sliceRow] |- Repeated 4x for
396- // / : vector<[4]xi1> from vector<[16]x[4]xi1> | all tiles in
397- // / vector.transfer_write | [16]x[4]
398- // / slice, dest[x + sliceRow, y + sliceCol], slice_mask |
399- // / : vector<[4]xf32>, memref<?x?xf32> ┘
390+ // / scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
391+ // / %upper_slice_y = arith.addi %slice_idx, %y : index
392+ // / %upper_slice_mask = vector.extract %mask[%slice_idx]
393+ // / : vector<[8]xi1> from vector<[16]x[8]xi1>
394+ // / %upper_slice = vector.extract %upper_tile[%slice_idx]
395+ // / : vector<[8]xi16> from vector<[8]x[8]xi16>
396+ // / vector.transfer_write %upper_slice,
397+ // / %dest[%upper_slice_y, %x], %upper_slice_mask
398+ // / : vector<[8]xi16>, memref<?x?xi16>
399+ // / // Same again for the lower tile:
400+ // / %lower_slice_idx = arith.addi %c8_vscale, %slice_idx : index
401+ // / %lower_slice_y = arith.addi %lower_slice_idx, %y : index
402+ // / %lower_slice_mask = vector.extract %mask[%lower_slice_idx]
403+ // / : vector<[8]xi1> from vector<[16]x[8]xi1>
404+ // / %lower_slice = vector.extract %lower_tile[%slice_idx]
405+ // / : vector<[8]xi16> from vector<[8]x[8]xi16>
406+ // / vector.transfer_write %lower_slice,
407+ // / %dest[%lower_slice_y, %x], %lower_slice_mask
408+ // / : vector<[8]xi16>, memref<?x?xi16>
400409// / }
401410// / ```
402411struct LegalizeMultiTileTransferWriteAsStoreLoop
0 commit comments