@@ -174,11 +174,17 @@ func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0:
174174func.func @transfer_write_f16_scalable_16x8 (%dest: memref <?x?xf16 >, %vec: vector <[16 ]x[8 ]xf16 >)
175175{
176176 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
177+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
177178 // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
178179 // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
179180 // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
180- // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
181- // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
181+ // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
182+ // CHECK-NEXT: %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
183+ // CHECK-NEXT: vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
184+ // CHECK-NEXT: %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
185+ // CHECK-NEXT: %[[BOTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
186+ // CHECK-NEXT: vector.transfer_write %[[BOTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
187+ // CHECK-NEXT: }
182188 // CHECK-NEXT: return
183189 %c0 = arith.constant 0 : index
184190 vector.transfer_write %vec , %dest [%c0 , %c0 ] {in_bounds = [true , true ]} : vector <[16 ]x[8 ]xf16 >, memref <?x?xf16 >
@@ -201,6 +207,47 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
201207
202208// -----
203209
210+ // CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
211+ // CHECK-SAME: %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
212+ // CHECK-SAME: %[[DIM_0:[a-z0-9]+]]: index,
213+ // CHECK-SAME: %[[DIM_1:[a-z0-9]+]]: index,
214+ // CHECK-SAME: %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
215+ // CHECK-SAME: %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
216+ // CHECK-SAME: %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
217+ // CHECK-SAME: %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
218+ func.func @transfer_write_f32_scalable_8x8_masked (%dest: memref <?x?xf32 >, %dim0: index , %dim1: index , %vec: vector <[8 ]x[8 ]xf32 >)
219+ {
220+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
221+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
222+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
223+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
224+ // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
225+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
226+ // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
227+ // CHECK-NEXT: %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
228+ // CHECK-NEXT: %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
229+ // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
230+ // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
231+ // CHECK-NEXT: %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
232+ // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
233+ // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
234+ // CHECK-NEXT: %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
235+ // CHECK-NEXT: %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
236+ // CHECK-NEXT: %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
237+ // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
238+ // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
239+ // CHECK-NEXT: %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
240+ // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
241+ // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE:.*]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
242+ // CHECK-NEXT: }
243+ %c0 = arith.constant 0 : index
244+ %mask = vector.create_mask %dim0 , %dim1 : vector <[8 ]x[8 ]xi1 >
245+ vector.transfer_write %vec , %dest [%c0 , %c0 ], %mask {in_bounds = [true , true ]} : vector <[8 ]x[8 ]xf32 >, memref <?x?xf32 >
246+ return
247+ }
248+
249+ // -----
250+
204251#transpose = affine_map <(d0 , d1 ) -> (d1 , d0 )>
205252
206253// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
@@ -209,6 +256,7 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
209256func.func @transpose_f32_scalable_4x16_via_read (%src: memref <?x?xf32 >, %dest: memref <?x?xf32 >)
210257{
211258 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
259+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
212260 // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
213261 // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
214262 // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
@@ -221,10 +269,19 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: me
221269 // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
222270 // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
223271 // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
224- // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
225- // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
226- // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
227- // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
272+ // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
273+ // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
274+ // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
275+ // CHECK-NEXT: %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
276+ // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
277+ // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
278+ // CHECK-NEXT: %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
279+ // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
280+ // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
281+ // CHECK-NEXT: %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
282+ // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
283+ // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
284+ // CHECK-NEXT: }
228285 // CHECK-NEXT: return
229286 %c0 = arith.constant 0 : index
230287 %pad = arith.constant 0.0 : f32
0 commit comments