@@ -113,8 +113,70 @@ module {
113113
114114// CHECK: func.func @entry(%[[ARG0:.*]]: tensor<2x4x8x1x2xbf16>) -> tensor<2x2x8x4xbf16> {
115115// CHECK: vector.transfer_write
116- // CHECK-NOT: %[[vec1:.*]] = vector.transfer_read
117- // CHECK-NOT: %[[vec2:.*]] = vector.transfer_read
118- // CHECK-NOT: %[[vec3:.*]] = vector.transfer_read
119- // CHECK-NOT: %[[vec4:.*]] = vector.contract
120- // CHECK-NOT: vector.transfer_write %[[vec4]]
116+ // CHECK: vector.transfer_read
117+ // CHECK: vector.transfer_read
118+ // CHECK: vector.contract
119+ // CHECK: vector.transfer_write
120+
121+ // -----
122+
// Mixed-precision (i8 -> i32 accumulate) VNNI-style contraction expressed as a
// linalg.generic with explicit extsi/muli/addi body. The vectorizer must fold
// the sign-extensions into a single mixed-type vector.contract instead of
// emitting standalone arith.extsi, broadcasts, or transposes.
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6, d5, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)>
module {
  func.func @vectorize_contract_mixed_precision_int(
      %arg0: tensor<1x2x32x8x4xi8>, %arg1: tensor<2x2x8x32x4xi8>,
      %arg2: tensor<1x2x32x32xi32>) -> tensor<1x2x32x32xi32> {
    %0 = linalg.generic {
        indexing_maps = [#map, #map1, #map2],
        iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1 : tensor<1x2x32x8x4xi8>, tensor<2x2x8x32x4xi8>)
      outs(%arg2 : tensor<1x2x32x32xi32>) {
    ^bb0(%in: i8, %in_0: i8, %out: i32):
      // Body values renamed %1..%4: reusing %0 here would be an SSA-value
      // redefinition inside a non-isolated region, which the parser rejects.
      %1 = arith.extsi %in : i8 to i32
      %2 = arith.extsi %in_0 : i8 to i32
      %3 = arith.muli %1, %2 : i32
      %4 = arith.addi %out, %3 : i32
      linalg.yield %4 : i32
    } -> tensor<1x2x32x32xi32>
    return %0 : tensor<1x2x32x32xi32>
  }
}

// CHECK-LABEL: @vectorize_contract_mixed_precision_int
// CHECK: vector.transfer_read{{.*}}: tensor<1x2x32x8x4xi8>, vector<1x2x32x8x4xi8>
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.transpose
// CHECK: vector.transfer_read{{.*}}: tensor<2x2x8x32x4xi8>, vector<2x2x8x32x4xi8>
// CHECK: vector.transfer_read{{.*}}: tensor<1x2x32x32xi32>, vector<1x2x32x32xi32>
// CHECK-NOT: arith.extsi
// CHECK: vector.contract
// CHECK: vector.transfer_write
155+
156+ // -----
157+
// Mixed-precision (bf16 -> f32 accumulate) VNNI-style contraction expressed
// with the linalg.contract named op. Like the integer variant above, this must
// vectorize directly to a mixed-type vector.contract with no arith.extf,
// broadcast, or transpose in between.
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
func.func @vectorize_contract_mixed_precision_float(
    %arg0: tensor<256x128x2xbf16>, %arg1: tensor<128x256x2xbf16>,
    %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
  %0 = linalg.contract
      indexing_maps = [#map, #map1, #map2]
      ins(%arg0, %arg1 : tensor<256x128x2xbf16>, tensor<128x256x2xbf16>)
      outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
  return %0 : tensor<256x256xf32>
}

// Ensure that mixed precision contraction vectorizes cleanly
// without extra operations and/or dimensions.

// CHECK-LABEL: @vectorize_contract_mixed_precision_float
// CHECK: vector.transfer_read{{.*}}: tensor<256x128x2xbf16>, vector<256x128x2xbf16>
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.transpose
// CHECK: vector.transfer_read{{.*}}: tensor<128x256x2xbf16>, vector<128x256x2xbf16>
// CHECK: vector.transfer_read{{.*}}: tensor<256x256xf32>, vector<256x256xf32>
// CHECK-NOT: arith.extf
// CHECK: vector.contract
// CHECK: vector.transfer_write
0 commit comments