@@ -265,14 +265,13 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
   // Retrieve the element type (f32, bf16, or f16)
   auto subviewOpAcc =
       vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
-  auto subviewOpLhs =
-      vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
-
+  auto subviewOpLhs =
+      vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
 
   auto elementType =
       (cast<MemRefType>(subviewOpLhs.getType())).getElementType();
-  auto outsElementType=
-      (cast<MemRefType>(subviewOpAcc.getType())).getElementType();
+  auto outsElementType =
+      (cast<MemRefType>(subviewOpAcc.getType())).getElementType();
 
   // We get the target architecture and decide on uKernel lowering using flags
   bool avx512 = vnni::utils::hasAVX512();
@@ -409,6 +408,20 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
   return rewriter.notifyMatchFailure(
       contractOp, "Affine map permutation not supported.");
 
+  // Lowering is based on the M and N tile sizes. If M >= N: load all of
+  // the B matrix, then broadcast A one-by-one + FMA.
+  // If N > M: do the opposite: broadcast all of the A matrix, then load
+  // B one-by-one + FMA.
+  // With this lowering we reduce register loads by hoisting whichever
+  // set is smaller (the B loads or the A broadcasts) and issuing the
+  // larger set of B loads or A broadcasts in a LIFO manner. This also
+  // reduces the probability of register spills.
+  bool mDriven = true;
+  int64_t nBlock = N / sizeFactor;
+
+  if (nBlock > M)
+    mDriven = false;
+
   rewriter.setInsertionPoint(mForOp);
   auto i32Type = rewriter.getIntegerType(32);
   auto i16Type = rewriter.getIntegerType(16);
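For intuition, the dispatch added above reduces to comparing M with the number of N-side vector blocks. A minimal sketch, assuming M, N, and sizeFactor as in the patch (the standalone helper isMDriven is hypothetical):

    #include <cstdint>

    // Stay M-driven (hoist all B loads, broadcast A row by row) unless
    // B needs more vector registers per K step than there are A rows.
    static bool isMDriven(int64_t M, int64_t N, int64_t sizeFactor) {
      int64_t nBlock = N / sizeFactor; // B vector registers per K step
      return nBlock <= M;
    }
    // e.g. M = 4, N = 64, sizeFactor = 16: nBlock = 4 <= 4 -> M-driven;
    //      M = 2, N = 64, sizeFactor = 16: nBlock = 4 >  2 -> N-driven.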
@@ -436,8 +449,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(
             reductionForOp.getLoc(), j);
         auto valueCRow = rewriter.create<vector::LoadOp>(
-            reductionForOp.getLoc(), VectorType::get(sizeFactor, outsElementType),
-            subviewOpAcc, ValueRange{indexOp_A, indexOp_B});
+            reductionForOp.getLoc(),
+            VectorType::get(sizeFactor, outsElementType), subviewOpAcc,
+            ValueRange{indexOp_A, indexOp_B});
         auto bitcast_i16 = rewriter.create<vector::BitCastOp>(
             reductionForOp.getLoc(), VectorType::get(sizeFactor, i16Type),
             valueCRow);
@@ -465,8 +479,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(
             reductionForOp.getLoc(), j);
         auto valueCRow = rewriter.create<vector::LoadOp>(
-            reductionForOp.getLoc(), VectorType::get(sizeFactor, outsElementType),
-            subviewOpAcc, ValueRange{indexOp_A, indexOp_B});
+            reductionForOp.getLoc(),
+            VectorType::get(sizeFactor, outsElementType), subviewOpAcc,
+            ValueRange{indexOp_A, indexOp_B});
         auto f32CVector = rewriter.create<arith::ExtFOp>(
             reductionForOp.getLoc(),
             VectorType::get({8}, rewriter.getF32Type()),
@@ -484,8 +499,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(
             reductionForOp.getLoc(), j);
         auto valueCRow = rewriter.create<vector::LoadOp>(
-            reductionForOp.getLoc(), VectorType::get(sizeFactor, outsElementType),
-            subviewOpAcc, ValueRange{indexOp_A, indexOp_B});
+            reductionForOp.getLoc(),
+            VectorType::get(sizeFactor, outsElementType), subviewOpAcc,
+            ValueRange{indexOp_A, indexOp_B});
         loopItrArgs.push_back(valueCRow);
       }
     }
@@ -559,8 +575,8 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
       rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
       rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
 
-  // uKernel lowering for f32 type. Target: avx512 instructions
-  if (isF32 && avx512) {
+  // uKernel lowering for f32 type. M -> N.
+  if (isF32 && mDriven) {
     // Load elements of B matrix and store in a DS
     for (int j = 0; j < N; j = j + sizeFactor) {
       Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
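For readers, the schedule this mDriven branch emits has the following shape; a scalar stand-in where each float models one sizeFactor-wide vector (mDrivenSchedule, aRows, bVecs, and the flat M-major acc layout are illustrative assumptions, not the pattern's actual data structures):

    #include <cstdint>
    #include <vector>

    // M-driven: all B blocks are loaded once up front, then each A row
    // is broadcast exactly once and FMA'd against every B block.
    void mDrivenSchedule(const std::vector<float> &aRows,
                         const std::vector<float> &bVecs,
                         std::vector<float> &acc, int64_t M,
                         int64_t nBlock) {
      for (int64_t i = 0; i < M; ++i) {      // one A broadcast per row
        float a = aRows[i];
        for (int64_t j = 0; j < nBlock; ++j) // FMA into M-major acc
          acc[i * nBlock + j] += a * bVecs[j];
      }
    }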
@@ -606,8 +622,7 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
         evenFMAs.push_back(oddFMAs[j + (i * (N / sizeFactor))]);
       }
     }
-  } else if (isF32 && avx2) { // uKernel lowering for f32 type.
-                              // Target: avx2 instructions
+  } else if (isF32 && !mDriven) { // N -> M.
     // Load elements of A matrix and store in a DS
     for (int i = 0; i < M; i++) {
       Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
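The evenFMAs/oddFMAs shuffle in the hunk above is a pure index permutation from M-major to N-major order; a small hypothetical helper that matches evenFMAs.push_back(oddFMAs[j + (i * (N / sizeFactor))]):

    #include <cstdint>
    #include <vector>

    // oddFMAs is M-major: index = i * nBlock + j.
    // The C-matrix iter-args are N-major: index = j * M + i.
    template <typename T>
    std::vector<T> mMajorToNMajor(const std::vector<T> &mMajor,
                                  int64_t M, int64_t nBlock) {
      std::vector<T> nMajor;
      nMajor.reserve(mMajor.size());
      for (int64_t j = 0; j < nBlock; ++j)
        for (int64_t i = 0; i < M; ++i)
          nMajor.push_back(mMajor[i * nBlock + j]);
      return nMajor;
    }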
@@ -650,58 +665,105 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
   // bf16 type + avx512. uKernel lowering for machines like
   // cpx (zen5) to target avx512bf16dp.
   if (bf16dp && isBF16) {
-    // Load elements of B matrix and store in a DS
-    for (int j = 0; j < N; j = j + sizeFactor) {
-      Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
-          reductionForOp.getLoc(), j);
-      auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
-          kForOp.getLoc(), VectorType::get(32, elementType),
-          rhsClone->getResult(0),
-          ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
-                     indexOp_c0});
-      matf32.push_back(valueRow);
-    }
 
-    // Load elements of A matrix, do FMA, and store in a DS
-    for (int i = 0; i < M; i++) {
-      Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
-          reductionForOp.getLoc(), i);
-      auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
-          kForOp.getLoc(), VectorType::get({vnni}, elementType),
-          lhsClone->getResult(0),
-          ValueRange{indexOp_c0, indexOp_i, indexOp_c0,
-                     indexOp_c0});
-      auto bitcastValue_i32 =
-          rewriterNewKForOp.create<vector::BitCastOp>(
-              kForOp.getLoc(),
-              VectorType::get({1},
-                              rewriterNewKForOp.getI32Type()),
-              valueRow);
-      auto bcst_i32 =
-          rewriterNewKForOp.create<vector::BroadcastOp>(
-              kForOp.getLoc(),
-              VectorType::get(sizeFactor,
-                              rewriterNewKForOp.getI32Type()),
-              bitcastValue_i32);
-      auto valuef32 = rewriterNewKForOp.create<vector::BitCastOp>(
-          kForOp.getLoc(),
-          VectorType::get(32, rewriterNewKForOp.getBF16Type()),
-          bcst_i32);
+    if (mDriven) { // M -> N
+      // Load elements of B matrix and store in a DS
+      for (int j = 0; j < N; j = j + sizeFactor) {
+        Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
+            reductionForOp.getLoc(), j);
+        auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
+            kForOp.getLoc(), VectorType::get(32, elementType),
+            rhsClone->getResult(0),
+            ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
+                       indexOp_c0});
+        matf32.push_back(valueRow);
+      }
+
+      // Load elements of A matrix, do FMA, and store in a DS
+      for (int i = 0; i < M; i++) {
+        Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
+            reductionForOp.getLoc(), i);
+        auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
+            kForOp.getLoc(), VectorType::get({vnni}, elementType),
+            lhsClone->getResult(0),
+            ValueRange{indexOp_c0, indexOp_i, indexOp_c0,
+                       indexOp_c0});
+        auto bitcastValue_i32 =
+            rewriterNewKForOp.create<vector::BitCastOp>(
+                kForOp.getLoc(), VectorType::get({1}, i32Type),
+                valueRow);
+        auto bcst_i32 =
+            rewriterNewKForOp.create<vector::BroadcastOp>(
+                kForOp.getLoc(),
+                VectorType::get(sizeFactor, i32Type),
+                bitcastValue_i32);
+        auto valuef32 =
+            rewriterNewKForOp.create<vector::BitCastOp>(
+                kForOp.getLoc(),
+                VectorType::get(32,
+                                rewriterNewKForOp.getBF16Type()),
+                bcst_i32);
+        for (int j = 0; j < (N / sizeFactor); j++) {
+          auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
+              kForOp.getLoc(), dstType,
+              iterArgsNewKForOp[i + (j * M)], valuef32,
+              matf32[j]);
+          oddFMAs.push_back(dp);
+        }
+      }
+
+      // Re-arrange the stored FMAs into N -> M order. The C matrix
+      // is loaded in N -> M order (e.g., c[0][0], c[1][0]), while
+      // the dps are created in M -> N order ({[0][0], [0][16], ...}),
+      // so shuffle M -> N into N -> M.
       for (int j = 0; j < (N / sizeFactor); j++) {
-        auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
-            kForOp.getLoc(), dstType,
-            iterArgsNewKForOp[i + (j * M)], valuef32, matf32[j]);
-        oddFMAs.push_back(dp);
+        for (int i = 0; i < M; i++) {
+          evenFMAs.push_back(oddFMAs[j + (i * (N / sizeFactor))]);
+        }
       }
-    }
 
-    // Re-arrange the stored FMAs in order of N -> M.
-    // We load C matrix with N -> M. For example: c[0][0], c[1][0]
-    // We do dp as M -> N order { [0][0], [0][16] ...}. So,
-    // shuffling the M -> N to N -> M order
-    for (int j = 0; j < (N / sizeFactor); j++) {
+    } else { // N -> M
       for (int i = 0; i < M; i++) {
-        evenFMAs.push_back(oddFMAs[j + (i * (N / sizeFactor))]);
+        Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(
+            reductionForOp.getLoc(), i);
+        auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
+            kForOp.getLoc(), VectorType::get({vnni}, elementType),
+            lhsClone->getResult(0),
+            ValueRange{indexOp_c0, indexOp_i, indexOp_c0,
+                       indexOp_c0});
+        auto bitcastValue_i32 =
+            rewriterNewKForOp.create<vector::BitCastOp>(
+                kForOp.getLoc(), VectorType::get({1}, i32Type),
+                valueRow);
+        auto bcst_i32 =
+            rewriterNewKForOp.create<vector::BroadcastOp>(
+                kForOp.getLoc(),
+                VectorType::get(sizeFactor, i32Type),
+                bitcastValue_i32);
+        auto valuef32 =
+            rewriterNewKForOp.create<vector::BitCastOp>(
+                kForOp.getLoc(),
+                VectorType::get(32,
+                                rewriterNewKForOp.getBF16Type()),
+                bcst_i32);
+        matf32.push_back(valuef32);
+      }
+
+      for (int j = 0, k = 0; j < N; j = j + sizeFactor) {
+        Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(
+            reductionForOp.getLoc(), j);
+        auto valueRow = rewriterNewKForOp.create<vector::LoadOp>(
+            kForOp.getLoc(), VectorType::get(32, elementType),
+            rhsClone->getResult(0),
+            ValueRange{indexOp_c0, indexOp_c0, indexOp_j,
+                       indexOp_c0});
+        for (int i = 0; i < M; i++) {
+          auto dp = rewriter.create<mlir::x86vector::DotBF16Op>(
+              kForOp.getLoc(), dstType, iterArgsNewKForOp[k],
+              matf32[i], valueRow);
+          k++;
+          evenFMAs.push_back(dp);
+        }
       }
     }
   }
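Both branches above rely on the same VNNI broadcast trick for A: two adjacent bf16 values are bitcast to a single i32, splatted across the sixteen i32 lanes, and bitcast back to 32 bf16 lanes, so every DotBF16Op lane sees the same (a[2k], a[2k+1]) pair. A scalar sketch under that reading (uint16_t stands in for bf16 storage; broadcastBf16Pair is a made-up name):

    #include <array>
    #include <cstdint>
    #include <cstring>

    std::array<uint16_t, 32> broadcastBf16Pair(uint16_t even, uint16_t odd) {
      uint32_t packed;
      const uint16_t pair[2] = {even, odd};
      std::memcpy(&packed, pair, sizeof(packed)); // 2 x bf16 -> 1 x i32
      std::array<uint16_t, 32> lanes;
      for (int l = 0; l < 16; ++l)                // splat to 16 x i32
        std::memcpy(&lanes[2 * l], &packed, sizeof(packed));
      return lanes;                               // viewed as 32 x bf16
    }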
@@ -840,7 +902,9 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
 
   // uKernel lowering for AVX2 machines
   // Target: (a) f16 and bf16 for srf kind of machines
-  //         (b) bf16 fallback + avx2 instructions
+  //         (b) bf16 fallback + avx2 instructions.
+  // TODO: update the lowering based on M & N. For now it
+  // defaults to M -> N.
   if (srf || (fallback && avx2 && !avx512)) {
     // Load odd elements of A Matrix and store in a DS
     for (int i = 0; i < M; i++) {