Skip to content

Commit bdde3e9

Browse files
authored
[Linalg-Vectorizer] Add pattern for mixed precision vector.contract. (#1067)
This pattern folds `arith.ext*` ops present in high-level mixed-precision linalg contraction ops into a mixed-precision `vector.contract` op, so the extension is performed implicitly by the contraction instead of by separate extension ops.
1 parent baa2bbe commit bdde3e9

File tree

4 files changed

+73
-9
lines changed

4 files changed

+73
-9
lines changed

lib/TPP/Transforms/LinalgVectorize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ struct LinalgVectorize
7878
tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
7979
patterns.add<linalg::CopyVectorizationPattern>(ctx);
8080
vector::populateVectorStepLoweringPatterns(patterns);
81+
vector::populateFoldArithExtensionPatterns(patterns);
8182

8283
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
8384
return signalPassFailure();

lib/TPP/Transforms/Vectorization.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "TPP/Dialect/Xsmm/XsmmUtils.h"
1212
#include "mlir/Dialect/Affine/Utils.h"
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1415
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1516
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1617
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -37,8 +38,6 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
3738

3839
LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
3940
PatternRewriter &rewriter) const override {
40-
if (!linalgOp.hasPureBufferSemantics())
41-
return failure();
4241
if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) ==
4342
xsmm::DataTypeAttr::get(rewriter.getContext(),
4443
xsmm::DataType::BF16) &&
@@ -107,6 +106,7 @@ struct VectorizationPass
107106

108107
void populateCombinePatterns(RewritePatternSet &patterns) {
109108
patterns.add<LinalgToVector<linalg::BatchReduceMatmulOp>,
109+
LinalgToVector<linalg::ContractOp>,
110110
LinalgToVector<linalg::TransposeOp>,
111111
LinalgToVector<linalg::FillOp>>(patterns.getContext());
112112
patterns.add<LinalgGenericToVector>(patterns.getContext());
@@ -117,6 +117,7 @@ struct VectorizationPass
117117
populateCombinePatterns(patterns);
118118
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
119119
vector::populateVectorReductionToContractPatterns(patterns);
120+
vector::populateFoldArithExtensionPatterns(patterns);
120121
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
121122
}
122123
};

test/Passes/linalg-vectorize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func.func @vectorize_contract_mixed_precision_float(
5454
// CHECK-NOT: vector.transpose
5555
// CHECK: vector.transfer_read{{.*}}: tensor<128x256x2xbf16>, vector<128x256x2xbf16>
5656
// CHECK: vector.transfer_read{{.*}}: tensor<256x256xf32>, vector<256x256xf32>
57-
// CHECK-COUNT-2: arith.extf
57+
// CHECK-NOT: arith.extf
5858
// CHECK: vector.contract
5959
// CHECK: vector.transfer_write
6060

@@ -92,7 +92,7 @@ module {
9292
// CHECK-NOT: vector.transpose
9393
// CHECK: vector.transfer_read{{.*}}: tensor<2x2x8x32x4xi8>, vector<2x2x8x32x4xi8>
9494
// CHECK: vector.transfer_read{{.*}}: tensor<1x2x32x32xi32>, vector<1x2x32x32xi32>
95-
// CHECK-COUNT-2: arith.extsi
95+
// CHECK-NOT: arith.extsi
9696
// CHECK: vector.contract
9797
// CHECK: vector.transfer_write
9898

test/Passes/pass-vectorization.mlir

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,70 @@ module {
113113

114114
// CHECK: func.func @entry(%[[ARG0:.*]]: tensor<2x4x8x1x2xbf16>) -> tensor<2x2x8x4xbf16> {
115115
// CHECK: vector.transfer_write
116-
// CHECK-NOT: %[[vec1:.*]] = vector.transfer_read
117-
// CHECK-NOT: %[[vec2:.*]] = vector.transfer_read
118-
// CHECK-NOT: %[[vec3:.*]] = vector.transfer_read
119-
// CHECK-NOT: %[[vec4:.*]] = vector.contract
120-
// CHECK-NOT: vector.transfer_write %[[vec4]]
116+
// CHECK: vector.transfer_read
117+
// CHECK: vector.transfer_read
118+
// CHECK: vector.contract
119+
// CHECK: vector.transfer_write
120+
121+
// -----
122+
123+
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6, d3)>
124+
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6, d5, d3)>
125+
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)>
126+
module {
127+
func.func @vectorize_contract_mixed_precision_int(
128+
%arg0: tensor<1x2x32x8x4xi8>, %arg1: tensor<2x2x8x32x4xi8>,
129+
%arg2: tensor<1x2x32x32xi32>) -> tensor<1x2x32x32xi32> {
130+
%0 = linalg.generic {
131+
indexing_maps = [#map, #map1, #map2],
132+
iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]}
133+
ins(%arg0, %arg1 : tensor<1x2x32x8x4xi8>, tensor<2x2x8x32x4xi8>)
134+
outs(%arg2 : tensor<1x2x32x32xi32>) {
135+
^bb0(%in: i8, %in_0: i8, %out: i32):
136+
%0 = arith.extsi %in : i8 to i32
137+
%1 = arith.extsi %in_0 : i8 to i32
138+
%2 = arith.muli %0, %1 : i32
139+
%3 = arith.addi %out, %2 : i32
140+
linalg.yield %3 : i32
141+
} -> tensor<1x2x32x32xi32>
142+
return %0 : tensor<1x2x32x32xi32>
143+
}
144+
}
145+
146+
// CHECK-LABEL: @vectorize_contract_mixed_precision_int
147+
// CHECK: vector.transfer_read{{.*}}: tensor<1x2x32x8x4xi8>, vector<1x2x32x8x4xi8>
148+
// CHECK-NOT: vector.broadcast
149+
// CHECK-NOT: vector.transpose
150+
// CHECK: vector.transfer_read{{.*}}: tensor<2x2x8x32x4xi8>, vector<2x2x8x32x4xi8>
151+
// CHECK: vector.transfer_read{{.*}}: tensor<1x2x32x32xi32>, vector<1x2x32x32xi32>
152+
// CHECK-NOT: arith.extsi
153+
// CHECK: vector.contract
154+
// CHECK: vector.transfer_write
155+
156+
// -----
157+
158+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
159+
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
160+
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
161+
func.func @vectorize_contract_mixed_precision_float(
162+
%arg0: tensor<256x128x2xbf16>, %arg1: tensor<128x256x2xbf16>,
163+
%arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
164+
%0 = linalg.contract
165+
indexing_maps = [#map, #map1, #map2]
166+
ins(%arg0, %arg1 : tensor<256x128x2xbf16>, tensor<128x256x2xbf16>)
167+
outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
168+
return %0 : tensor<256x256xf32>
169+
}
170+
171+
// Ensure that mixed precision contraction vectorizes cleanly
172+
// without extra operations and/or dimensions.
173+
174+
// CHECK-LABEL: @vectorize_contract_mixed_precision_float
175+
// CHECK: vector.transfer_read{{.*}}: tensor<256x128x2xbf16>, vector<256x128x2xbf16>
176+
// CHECK-NOT: vector.broadcast
177+
// CHECK-NOT: vector.transpose
178+
// CHECK: vector.transfer_read{{.*}}: tensor<128x256x2xbf16>, vector<128x256x2xbf16>
179+
// CHECK: vector.transfer_read{{.*}}: tensor<256x256xf32>, vector<256x256xf32>
180+
// CHECK-NOT: arith.extf
181+
// CHECK: vector.contract
182+
// CHECK: vector.transfer_write

0 commit comments

Comments
 (0)