Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 127 additions & 63 deletions lib/TPP/Transforms/VectorContractToMicroKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,13 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
// Retrieve the element type (f32 or bf16 or f16)
auto subviewOpAcc =
vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
auto subviewOpLhs =
vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();

auto subviewOpLhs =
vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();

auto elementType =
(cast<MemRefType>(subviewOpLhs.getType())).getElementType();
auto outsElementType=
(cast<MemRefType>(subviewOpAcc.getType())).getElementType();
auto outsElementType =
(cast<MemRefType>(subviewOpAcc.getType())).getElementType();

// We get target architecture and decide on uKernel lowering using flags
bool avx512 = vnni::utils::hasAVX512();
Expand Down Expand Up @@ -409,6 +408,20 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
return rewriter.notifyMatchFailure(
contractOp, "Affine map permutation not supported.");

// Lowering is done based on M and N tile sizes. If M >= N: load all B
// matrix then broadcast A one-by-one + FMA.
// If N > M: perform the opposite. Broadcast A matrix then load B one-by-
// one + FMA.
// Following this kind of lowering, we reduce the register loads by
// stacking the fewer B loads or fewer A broadcasts and doing the larger B
// loads or A broadcasts in a LIFO manner. Finally, it helps in reducing
// the probability of register spills.
bool mDriven = true;
int64_t nBlock = N / sizeFactor;

if (nBlock > M)
mDriven = false;

rewriter.setInsertionPoint(mForOp);
auto i32Type = rewriter.getIntegerType(32);
auto i16Type = rewriter.getIntegerType(16);
Expand Down Expand Up @@ -436,8 +449,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(
reductionForOp.getLoc(), j);
auto valueCRow = rewriter.create<vector::LoadOp>(
reductionForOp.getLoc(), VectorType::get(sizeFactor, outsElementType),
subviewOpAcc, ValueRange{indexOp_A, indexOp_B});
reductionForOp.getLoc(),
VectorType::get(sizeFactor, outsElementType), subviewOpAcc,
ValueRange{indexOp_A, indexOp_B});
auto bitcast_i16 = rewriter.create<vector::BitCastOp>(
reductionForOp.getLoc(), VectorType::get(sizeFactor, i16Type),
valueCRow);
Expand Down Expand Up @@ -465,8 +479,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(
reductionForOp.getLoc(), j);
auto valueCRow = rewriter.create<vector::LoadOp>(
reductionForOp.getLoc(), VectorType::get(sizeFactor, outsElementType),
subviewOpAcc, ValueRange{indexOp_A, indexOp_B});
reductionForOp.getLoc(),
VectorType::get(sizeFactor, outsElementType), subviewOpAcc,
ValueRange{indexOp_A, indexOp_B});
auto f32CVector = rewriter.create<arith::ExtFOp>(
reductionForOp.getLoc(),
VectorType::get({8}, rewriter.getF32Type()),
Expand All @@ -484,8 +499,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(
reductionForOp.getLoc(), j);
auto valueCRow = rewriter.create<vector::LoadOp>(
reductionForOp.getLoc(), VectorType::get(sizeFactor, outsElementType),
subviewOpAcc, ValueRange{indexOp_A, indexOp_B});
reductionForOp.getLoc(),
VectorType::get(sizeFactor, outsElementType), subviewOpAcc,
ValueRange{indexOp_A, indexOp_B});
loopItrArgs.push_back(valueCRow);
}
}
Expand Down Expand Up @@ -559,8 +575,8 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};

// uKernel lowering for f32 type. Target: avx512 instructions
if (isF32 && avx512) {
// uKernel lowering for f32 type. M -> N.
if (isF32 && mDriven) {
// Load elements of B matrix and store in a DS
for (int j = 0; j < N; j = j + sizeFactor) {
Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
Expand Down Expand Up @@ -606,8 +622,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
evenFMAs.push_back(oddFMAs[j + (i * (N / sizeFactor))]);
}
}
} else if (isF32 && avx2) { // uKernel lowering for f32 type.
// Target: avx2 instructions
} else if (isF32 && !mDriven) { // N -> M.
// Load elements of A matrix and store in a DS
for (int i = 0; i < M; i++) {
Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
Expand Down Expand Up @@ -650,58 +665,105 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
// bf16 type + avx512. uKernel lowering for machines like
// cpx (zen5) to target avx512bf16dp.
if (bf16dp && isBF16) {
// Load elements of B matrix and store in a DS
for (int j = 0; j < N; j = j + sizeFactor) {
Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
reductionForOp.getLoc(), j);
auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
kForOp.getLoc(), VectorType::get(32, elementType),
rhsClone->getResult(0),
ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
indexOp_c0});
matf32.push_back(valueRow);
}

// Load elements of A matrix, do FMA, and store in a DS
for (int i = 0; i < M; i++) {
Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
reductionForOp.getLoc(), i);
auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
kForOp.getLoc(), VectorType::get({vnni}, elementType),
lhsClone->getResult(0),
ValueRange{indexOp_c0, indexOp_i, indexOp_c0,
indexOp_c0});
auto bitcastValue_i32 =
rewriterNewKForOp.create<vector::BitCastOp>(
kForOp.getLoc(),
VectorType::get({1},
rewriterNewKForOp.getI32Type()),
valueRow);
auto bcst_i32 =
rewriterNewKForOp.create<vector::BroadcastOp>(
kForOp.getLoc(),
VectorType::get(sizeFactor,
rewriterNewKForOp.getI32Type()),
bitcastValue_i32);
auto valuef32 = rewriterNewKForOp.create<vector::BitCastOp>(
kForOp.getLoc(),
VectorType::get(32, rewriterNewKForOp.getBF16Type()),
bcst_i32);
if (mDriven) { // M -> N
// Load elements of B matrix and store in a DS
for (int j = 0; j < N; j = j + sizeFactor) {
Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
reductionForOp.getLoc(), j);
auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
kForOp.getLoc(), VectorType::get(32, elementType),
rhsClone->getResult(0),
ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
indexOp_c0});
matf32.push_back(valueRow);
}

// Load elements of A matrix, do FMA, and store in a DS
for (int i = 0; i < M; i++) {
Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
reductionForOp.getLoc(), i);
auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
kForOp.getLoc(), VectorType::get({vnni}, elementType),
lhsClone->getResult(0),
ValueRange{indexOp_c0, indexOp_i, indexOp_c0,
indexOp_c0});
auto bitcastValue_i32 =
rewriterNewKForOp.create<vector::BitCastOp>(
kForOp.getLoc(), VectorType::get({1}, i32Type),
valueRow);
auto bcst_i32 =
rewriterNewKForOp.create<vector::BroadcastOp>(
kForOp.getLoc(),
VectorType::get(sizeFactor, i32Type),
bitcastValue_i32);
auto valuef32 =
rewriterNewKForOp.create<vector::BitCastOp>(
kForOp.getLoc(),
VectorType::get(32,
rewriterNewKForOp.getBF16Type()),
bcst_i32);
for (int j = 0; j < (N / sizeFactor); j++) {
auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
kForOp.getLoc(), dstType,
iterArgsNewKForOp[i + (j * M)], valuef32,
matf32[j]);
oddFMAs.push_back(dp);
}
}

// Re-arrange the stored FMAs in order of N -> M.
// We load C matrix with N -> M. For example: c[0][0],
// c[1][0] We do dp as M -> N order { [0][0], [0][16] ...}.
// So, shuffling the M -> N to N -> M order
for (int j = 0; j < (N / sizeFactor); j++) {
auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
kForOp.getLoc(), dstType,
iterArgsNewKForOp[i + (j * M)], valuef32, matf32[j]);
oddFMAs.push_back(dp);
for (int i = 0; i < M; i++) {
evenFMAs.push_back(oddFMAs[j + (i * (N / sizeFactor))]);
}
}
}

// Re-arrange the stored FMAs in order of N -> M.
// We load C matrix with N -> M. For example: c[0][0], c[1][0]
// We do dp as M -> N order { [0][0], [0][16] ...}. So,
// shuffling the M -> N to N -> M order
for (int j = 0; j < (N / sizeFactor); j++) {
} else { // N -> M
for (int i = 0; i < M; i++) {
evenFMAs.push_back(oddFMAs[j + (i * (N / sizeFactor))]);
Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
reductionForOp.getLoc(), i);
auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
kForOp.getLoc(), VectorType::get({vnni}, elementType),
lhsClone->getResult(0),
ValueRange{indexOp_c0, indexOp_i, indexOp_c0,
indexOp_c0});
auto bitcastValue_i32 =
rewriterNewKForOp.create<vector::BitCastOp>(
kForOp.getLoc(), VectorType::get({1}, i32Type),
valueRow);
auto bcst_i32 =
rewriterNewKForOp.create<vector::BroadcastOp>(
kForOp.getLoc(),
VectorType::get(sizeFactor, i32Type),
bitcastValue_i32);
auto valuef32 =
rewriterNewKForOp.create<vector::BitCastOp>(
kForOp.getLoc(),
VectorType::get(32,
rewriterNewKForOp.getBF16Type()),
bcst_i32);
matf32.push_back(valuef32);
}

for (int j = 0, k = 0; j < N; j = j + sizeFactor) {
Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
reductionForOp.getLoc(), j);
auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
kForOp.getLoc(), VectorType::get(32, elementType),
rhsClone->getResult(0),
ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
indexOp_c0});
for (int i = 0; i < M; i++) {
auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
kForOp.getLoc(), dstType, iterArgsNewKForOp[k],
matf32[i], valueRow);
k++;
evenFMAs.push_back(dp);
}
}
}
}
Expand Down Expand Up @@ -840,7 +902,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {

// uKernel lowering for AVX2 machines
// Target: (a) f16 and bf16 for srf kind of machines
// (b) bf16 fallback + avx2 instructions
// (b) bf16 fallback + avx2 instructions.
// TODO: update lowering based on M & N. Now it is
// default to M -> N
if (srf || (fallback && avx2 && !avx512)) {
// Load odd elements of A Matrix and store in a DS
for (int i = 0; i < M; i++) {
Expand Down
57 changes: 57 additions & 0 deletions test/Passes/uKernels/avx2/pass-vector-contract-to-FMAs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,63 @@ module {

// -----

#map_nm = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map_nm1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map_nm2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
module {
func.func @opt_register_4x3(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x24xf32>, %arg2: memref<4x24xf32>) -> memref<4x24xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c24 = arith.constant 24 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
scf.for %arg3 = %c0 to %c4 step %c4 {
scf.for %arg4 = %c0 to %c24 step %c24 {
%subview = memref.subview %arg2[%arg3, %arg4] [4, 24] [1, 1] : memref<4x24xf32> to memref<4x24xf32, strided<[24, 1], offset: ?>>
%0 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x24xf32, strided<[24, 1], offset: ?>>, vector<4x24xf32>
%1 = scf.for %arg5 = %c0 to %c1 step %c1 iter_args(%arg6 = %0) -> (vector<4x24xf32>) {
%2 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x24xf32>) {
%subview_0 = memref.subview %arg0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<1x4x32xf32> to memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>
%subview_1 = memref.subview %arg1[%arg5, %arg7, %arg4] [1, 1, 24] [1, 1, 1] : memref<1x32x24xf32> to memref<1x1x24xf32, strided<[768, 24, 1], offset: ?>>
%3 = vector.transfer_read %subview_0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>, vector<1x4x1xf32>
%4 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x24xf32, strided<[768, 24, 1], offset: ?>>, vector<1x1x24xf32>
%5 = vector.contract {indexing_maps = [#map_nm, #map_nm1, #map_nm2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %arg8 : vector<1x4x1xf32>, vector<1x1x24xf32> into vector<4x24xf32>
scf.yield %5 : vector<4x24xf32>
}
scf.yield %2 : vector<4x24xf32>
}
vector.transfer_write %1, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x24xf32>, memref<4x24xf32, strided<[24, 1], offset: ?>>
}
}
return %arg2 : memref<4x24xf32>
}
}

// CHECK-LABEL: func.func @opt_register_4x3
// CHECK: scf.for
// CHECK: vector.broadcast
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.broadcast
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.broadcast
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.broadcast
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>
// CHECK-NEXT: vector.fma{{.*}}vector<8xf32>

// -----

#no_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#no_map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#no_map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
Expand Down
Loading