
Commit 4a95804

[μKernels]: lowering based on m and n tile size (#1068)
This PR updates the micro-kernel lowering logic to pick the lowering order based on the `m` and `n` tile sizes.
1 parent bdde3e9 commit 4a95804
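
The new behavior in a nutshell: the pass now computes how many vector registers the N tile occupies (nBlock = N / sizeFactor) and compares that against M to pick the loop order. A minimal standalone sketch of that decision follows; the isMDriven helper, the hard-coded 16-lane sizeFactor, and the example tile sizes are illustrative assumptions, not code from the pass.

// Standalone sketch: choose the lowering order from the tile sizes,
// mirroring the new nBlock/mDriven check added in this commit.
#include <cstdint>
#include <cstdio>

static bool isMDriven(int64_t m, int64_t n, int64_t sizeFactor) {
  int64_t nBlock = n / sizeFactor; // N measured in vector registers
  return nBlock <= m;              // M >= N/sizeFactor: keep the B tile resident
}

int main() {
  // A 4x64 f32 tile gives nBlock = 4 <= M = 4, so B is loaded up front and A
  // is broadcast row by row (M -> N); a 2x64 tile flips to N -> M.
  std::printf("%d %d\n", isMDriven(4, 64, 16), isMDriven(2, 64, 16));
  return 0;
}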

4 files changed: +331 additions, -63 deletions


lib/TPP/Transforms/VectorContractToMicroKernels.cpp

Lines changed: 127 additions & 63 deletions
@@ -265,14 +265,13 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
     // Retrive the element type (f32 or bf16 or f16)
     auto subviewOpAcc =
         vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
-    auto subviewOpLhs =
-        vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
-
+    auto subviewOpLhs =
+        vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
 
     auto elementType =
         (cast<MemRefType>(subviewOpLhs.getType())).getElementType();
-    auto outsElementType=
-        (cast<MemRefType>(subviewOpAcc.getType())).getElementType();
+    auto outsElementType =
+        (cast<MemRefType>(subviewOpAcc.getType())).getElementType();
 
     // We get target architecture and decide on uKernel lowering using flags
     bool avx512 = vnni::utils::hasAVX512();
@@ -409,6 +408,20 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
       return rewriter.notifyMatchFailure(
           contractOp, "Affine map permutation not supported.");
 
+    // Lowering is done based on M and N tile sizes. If M >= N: load all B
+    // matrix then broadcast A one-by-one + FMA.
+    // If N > M: perform opposite. Broadcast A matrix then load B one-by-
+    // one + FMA.
+    // Following this kind of lowering, we reduce the register loads by
+    // stacking the less B loads or less A broadcasts and do the larger B
+    // loads or A broadcast in a LIFO manner. Finally, it helps in reducing
+    // the probability of register spills.
+    bool mDriven = true;
+    int64_t nBlock = N / sizeFactor;
+
+    if (nBlock > M)
+      mDriven = false;
+
     rewriter.setInsertionPoint(mForOp);
     auto i32Type = rewriter.getIntegerType(32);
     auto i16Type = rewriter.getIntegerType(16);
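
To see why the stacking described in the new comment reduces spills, here is a small worked example of the register counts for both orderings. It assumes the 4x24 f32 tile from the new AVX2 test, 8 f32 lanes per ymm register, and a naive one-register-per-live-value model; none of this code is part of the pass.

// Rough register-pressure estimate for both orderings of a 4x24 f32 tile
// (8 f32 lanes per AVX2 ymm). Illustration only, not code from the pass.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t M = 4, N = 24, sizeFactor = 8;
  int64_t nBlock = N / sizeFactor;      // 3 B vectors per row of the tile
  bool mDriven = !(nBlock > M);         // same condition the pass now uses

  int64_t acc = M * nBlock;             // accumulators are live either way
  int64_t mThenN = acc + nBlock + 1;    // all B loads resident + 1 A broadcast
  int64_t nThenM = acc + M + 1;         // all A broadcasts resident + 1 B load
  std::printf("nBlock=%lld mDriven=%d M->N needs ~%lld regs, N->M needs ~%lld\n",
              (long long)nBlock, (int)mDriven, (long long)mThenN,
              (long long)nThenM);       // here M -> N keeps fewer values live
  return 0;
}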
@@ -436,8 +449,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(
             reductionForOp.getLoc(), j);
         auto valueCRow = rewriter.create<vector::LoadOp>(
-            reductionForOp.getLoc(), VectorType::get(sizeFactor, outsElementType),
-            subviewOpAcc, ValueRange{indexOp_A, indexOp_B});
+            reductionForOp.getLoc(),
+            VectorType::get(sizeFactor, outsElementType), subviewOpAcc,
+            ValueRange{indexOp_A, indexOp_B});
         auto bitcast_i16 = rewriter.create<vector::BitCastOp>(
             reductionForOp.getLoc(), VectorType::get(sizeFactor, i16Type),
             valueCRow);
@@ -465,8 +479,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(
             reductionForOp.getLoc(), j);
         auto valueCRow = rewriter.create<vector::LoadOp>(
-            reductionForOp.getLoc(), VectorType::get(sizeFactor, outsElementType),
-            subviewOpAcc, ValueRange{indexOp_A, indexOp_B});
+            reductionForOp.getLoc(),
+            VectorType::get(sizeFactor, outsElementType), subviewOpAcc,
+            ValueRange{indexOp_A, indexOp_B});
         auto f32CVector = rewriter.create<arith::ExtFOp>(
             reductionForOp.getLoc(),
             VectorType::get({8}, rewriter.getF32Type()),
@@ -484,8 +499,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(
             reductionForOp.getLoc(), j);
         auto valueCRow = rewriter.create<vector::LoadOp>(
-            reductionForOp.getLoc(), VectorType::get(sizeFactor, outsElementType),
-            subviewOpAcc, ValueRange{indexOp_A, indexOp_B});
+            reductionForOp.getLoc(),
+            VectorType::get(sizeFactor, outsElementType), subviewOpAcc,
+            ValueRange{indexOp_A, indexOp_B});
         loopItrArgs.push_back(valueCRow);
       }
     }
@@ -559,8 +575,8 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
         rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
 
-    // uKernel lowering for f32 type. Target: avx512 instructions
-    if (isF32 && avx512) {
+    // uKernel lowering for f32 type. M -> N.
+    if (isF32 && mDriven) {
       // Load elements of B matrix and store in a DS
       for (int j = 0; j < N; j = j + sizeFactor) {
         Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
@@ -606,8 +622,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
           evenFMAs.push_back(oddFMAs[j + (i * (N / sizeFactor))]);
         }
       }
-    } else if (isF32 && avx2) { // uKernel lowering for f32 type.
-                                // Target: avx2 instructions
+    } else if (isF32 && !mDriven) { // N -> M.
       // Load elements of A matrix and store in a DS
       for (int i = 0; i < M; i++) {
         Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
@@ -650,58 +665,105 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
     // bf16 type + avx512. uKernel lowering for machines like
     // cpx (zen5) to target avx512bf16dp.
     if (bf16dp && isBF16) {
-      // Load elements of B matrix and store in a DS
-      for (int j = 0; j < N; j = j + sizeFactor) {
-        Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
-            reductionForOp.getLoc(), j);
-        auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
-            kForOp.getLoc(), VectorType::get(32, elementType),
-            rhsClone->getResult(0),
-            ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
-                       indexOp_c0});
-        matf32.push_back(valueRow);
-      }
 
-      // Load elements of A matrix, do FMA, and store in a DS
-      for (int i = 0; i < M; i++) {
-        Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
-            reductionForOp.getLoc(), i);
-        auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
-            kForOp.getLoc(), VectorType::get({vnni}, elementType),
-            lhsClone->getResult(0),
-            ValueRange{indexOp_c0, indexOp_i, indexOp_c0,
-                       indexOp_c0});
-        auto bitcastValue_i32 =
-            rewriterNewKForOp.create<vector::BitCastOp>(
-                kForOp.getLoc(),
-                VectorType::get({1},
-                                rewriterNewKForOp.getI32Type()),
-                valueRow);
-        auto bcst_i32 =
-            rewriterNewKForOp.create<vector::BroadcastOp>(
-                kForOp.getLoc(),
-                VectorType::get(sizeFactor,
-                                rewriterNewKForOp.getI32Type()),
-                bitcastValue_i32);
-        auto valuef32 = rewriterNewKForOp.create<vector::BitCastOp>(
-            kForOp.getLoc(),
-            VectorType::get(32, rewriterNewKForOp.getBF16Type()),
-            bcst_i32);
+      if (mDriven) { // M -> N
+        // Load elements of B matrix and store in a DS
+        for (int j = 0; j < N; j = j + sizeFactor) {
+          Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
+              reductionForOp.getLoc(), j);
+          auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
+              kForOp.getLoc(), VectorType::get(32, elementType),
+              rhsClone->getResult(0),
+              ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
+                         indexOp_c0});
+          matf32.push_back(valueRow);
+        }
+
+        // Load elements of A matrix, do FMA, and store in a DS
+        for (int i = 0; i < M; i++) {
+          Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
+              reductionForOp.getLoc(), i);
+          auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
+              kForOp.getLoc(), VectorType::get({vnni}, elementType),
+              lhsClone->getResult(0),
+              ValueRange{indexOp_c0, indexOp_i, indexOp_c0,
+                         indexOp_c0});
+          auto bitcastValue_i32 =
+              rewriterNewKForOp.create<vector::BitCastOp>(
+                  kForOp.getLoc(), VectorType::get({1}, i32Type),
+                  valueRow);
+          auto bcst_i32 =
+              rewriterNewKForOp.create<vector::BroadcastOp>(
+                  kForOp.getLoc(),
+                  VectorType::get(sizeFactor, i32Type),
+                  bitcastValue_i32);
+          auto valuef32 =
+              rewriterNewKForOp.create<vector::BitCastOp>(
+                  kForOp.getLoc(),
+                  VectorType::get(32,
+                                  rewriterNewKForOp.getBF16Type()),
+                  bcst_i32);
+          for (int j = 0; j < (N / sizeFactor); j++) {
+            auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
+                kForOp.getLoc(), dstType,
+                iterArgsNewKForOp[i + (j * M)], valuef32,
+                matf32[j]);
+            oddFMAs.push_back(dp);
+          }
+        }
+
+        // Re-arrange the stored FMAs in order of N -> M.
+        // We load C matrix with N -> M. For example: c[0][0],
+        // c[1][0] We do dp as M -> N order { [0][0], [0][16] ...}.
+        // So, shuffling the M -> N to N -> M order
         for (int j = 0; j < (N / sizeFactor); j++) {
-          auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
-              kForOp.getLoc(), dstType,
-              iterArgsNewKForOp[i + (j * M)], valuef32, matf32[j]);
-          oddFMAs.push_back(dp);
+          for (int i = 0; i < M; i++) {
+            evenFMAs.push_back(oddFMAs[j + (i * (N / sizeFactor))]);
+          }
         }
-      }
 
-      // Re-arrange the stored FMAs in order of N -> M.
-      // We load C matrix with N -> M. For example: c[0][0], c[1][0]
-      // We do dp as M -> N order { [0][0], [0][16] ...}. So,
-      // shuffling the M -> N to N -> M order
-      for (int j = 0; j < (N / sizeFactor); j++) {
+      } else { // N -> M
         for (int i = 0; i < M; i++) {
-          evenFMAs.push_back(oddFMAs[j + (i * (N / sizeFactor))]);
+          Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
+              reductionForOp.getLoc(), i);
+          auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
+              kForOp.getLoc(), VectorType::get({vnni}, elementType),
+              lhsClone->getResult(0),
+              ValueRange{indexOp_c0, indexOp_i, indexOp_c0,
+                         indexOp_c0});
+          auto bitcastValue_i32 =
+              rewriterNewKForOp.create<vector::BitCastOp>(
+                  kForOp.getLoc(), VectorType::get({1}, i32Type),
+                  valueRow);
+          auto bcst_i32 =
+              rewriterNewKForOp.create<vector::BroadcastOp>(
+                  kForOp.getLoc(),
+                  VectorType::get(sizeFactor, i32Type),
+                  bitcastValue_i32);
+          auto valuef32 =
+              rewriterNewKForOp.create<vector::BitCastOp>(
+                  kForOp.getLoc(),
+                  VectorType::get(32,
+                                  rewriterNewKForOp.getBF16Type()),
+                  bcst_i32);
+          matf32.push_back(valuef32);
         }
+
+        for (int j = 0, k = 0; j < N; j = j + sizeFactor) {
+          Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
+              reductionForOp.getLoc(), j);
+          auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
+              kForOp.getLoc(), VectorType::get(32, elementType),
+              rhsClone->getResult(0),
+              ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
+                         indexOp_c0});
+          for (int i = 0; i < M; i++) {
+            auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
+                kForOp.getLoc(), dstType, iterArgsNewKForOp[k],
+                matf32[i], valueRow);
+            k++;
+            evenFMAs.push_back(dp);
+          }
         }
       }
     }
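
The shuffle from M -> N production order back to N -> M (the order in which the C tile and iter_args were loaded) is a pure index permutation. A self-contained sketch follows; the 2x3 shape and the integer tags are made up for illustration, and only the j + i * nBlock indexing mirrors the pass.

// Illustration of re-ordering FMA results from M -> N (production order) to
// N -> M (the order of the loaded C tile). Only the j + i * nBlock indexing
// matches the pass; everything else is invented for the example.
#include <cstdio>
#include <vector>

int main() {
  int M = 2, nBlock = 3;             // 2 rows, 3 vector-wide columns
  std::vector<int> mThenN;           // plays the role of oddFMAs
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < nBlock; ++j)
      mThenN.push_back(i * 100 + j); // tag each result with (row, column)

  std::vector<int> nThenM;           // plays the role of evenFMAs
  for (int j = 0; j < nBlock; ++j)
    for (int i = 0; i < M; ++i)
      nThenM.push_back(mThenN[j + i * nBlock]);

  for (int v : nThenM)
    std::printf("%d ", v);           // prints: 0 100 1 101 2 102
  std::printf("\n");
  return 0;
}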
@@ -840,7 +902,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
 
     // uKernel lowering for AVX2 machines
     // Target: (a) f16 and bf16 for srf kind of machines
-    //         (b) bf16 fallback + avx2 instructions
+    //         (b) bf16 fallback + avx2 instructions.
+    // TODO: update lowering based on M & N. Now it is
+    // default to M -> N
     if (srf || (fallback && avx2 && !avx512)) {
       // Load odd elements of A Matrix and store in a DS
       for (int i = 0; i < M; i++) {

test/Passes/uKernels/avx2/pass-vector-contract-to-FMAs.mlir

Lines changed: 57 additions & 0 deletions
@@ -54,6 +54,63 @@ module {
 
 // -----
 
+#map_nm = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map_nm1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map_nm2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+module {
+  func.func @opt_register_4x3(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x24xf32>, %arg2: memref<4x24xf32>) -> memref<4x24xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %c24 = arith.constant 24 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    scf.for %arg3 = %c0 to %c4 step %c4 {
+      scf.for %arg4 = %c0 to %c24 step %c24 {
+        %subview = memref.subview %arg2[%arg3, %arg4] [4, 24] [1, 1] : memref<4x24xf32> to memref<4x24xf32, strided<[24, 1], offset: ?>>
+        %0 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x24xf32, strided<[24, 1], offset: ?>>, vector<4x24xf32>
+        %1 = scf.for %arg5 = %c0 to %c1 step %c1 iter_args(%arg6 = %0) -> (vector<4x24xf32>) {
+          %2 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x24xf32>) {
+            %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<1x4x32xf32> to memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>
+            %subview_1 = memref.subview %arg1[%arg5, %arg7, %arg4] [1, 1, 24] [1, 1, 1] : memref<1x32x24xf32> to memref<1x1x24xf32, strided<[768, 24, 1], offset: ?>>
+            %3 = vector.transfer_read %subview_0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>, vector<1x4x1xf32>
+            %4 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x24xf32, strided<[768, 24, 1], offset: ?>>, vector<1x1x24xf32>
+            %5 = vector.contract {indexing_maps = [#map_nm, #map_nm1, #map_nm2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %arg8 : vector<1x4x1xf32>, vector<1x1x24xf32> into vector<4x24xf32>
+            scf.yield %5 : vector<4x24xf32>
+          }
+          scf.yield %2 : vector<4x24xf32>
+        }
+        vector.transfer_write %1, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x24xf32>, memref<4x24xf32, strided<[24, 1], offset: ?>>
+      }
+    }
+    return %arg2 : memref<4x24xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @opt_register_4x3
+// CHECK: scf.for
+// CHECK: vector.broadcast
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.broadcast
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.broadcast
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.broadcast
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
+
+// -----
+
 #no_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #no_map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #no_map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
