@@ -606,17 +606,29 @@ class tinyBLAS_Q0_AVX {
         case 0x44:
             mc = 4;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<4>(m0, m, n0, n);
+#else
             gemm<4, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x43:
             mc = 4;
             nc = 3;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<3>(m0, m, n0, n);
+#else
             gemm<4, 3>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
             mc = 3;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<3>(m0, m, n0, n);
+#else
             gemm<3, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
             mc = 3;
@@ -626,26 +638,42 @@ class tinyBLAS_Q0_AVX {
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
 #else
         case 0x44:
         case 0x43:
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
 #endif
@@ -662,7 +690,11 @@ class tinyBLAS_Q0_AVX {
         case 0x41:
             mc = 4;
             nc = 1;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<1>(m0, m, n0, n);
+#else
             gemm<4, 1>(m0, m, n0, n);
+#endif
             break;
         case 0x22:
             mc = 2;
@@ -672,7 +704,11 @@ class tinyBLAS_Q0_AVX {
         case 0x14:
             mc = 1;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<1>(m0, m, n0, n);
+#else
             gemm<1, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x31:
             mc = 3;
@@ -708,6 +744,119 @@ class tinyBLAS_Q0_AVX {
         mnpack(m0, m, np, n);
     }
 
+#if defined(__AVX2__) && defined(__F16C__)
+    // Templated functions for gemm of dimensions 4xN
+    template <int RN>
+    NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
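+        // Partition the grid of 4-row by RN-column output tiles evenly across
+        // the nth threads; this thread (ith) computes the tiles [start, end).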
+        int64_t ytiles = (m - m0) / 4;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * 4;
+            int64_t jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][4] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
+                // Convert the delta values of the four blocks to float values
+                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
+                __m256i avec0 = load(A + lda * (ii + 0) + l);
+                __m256i avec1 = load(A + lda * (ii + 1) + l);
+                __m256i avec2 = load(A + lda * (ii + 2) + l);
+                __m256i avec3 = load(A + lda * (ii + 3) + l);
+                for (int64_t j = 0; j < RN; ++j) {
+                    __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
+                    // Compute the products of the four A deltas with the B delta
+                    // and replicate them across the 256-bit lane
+                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
+                    // Compute the dot products and scale each one by the matching delta product
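+                    // The shuffle immediates 0, 85, 170 and 255 (0b00000000,
+                    // 0b01010101, 0b10101010, 0b11111111) broadcast the delta
+                    // product of rows 0, 1, 2 and 3 respectively to all lanes.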
+                    Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(avec0, avec0),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
+                                    Cv[j][0]);
+                    Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(avec1, avec1),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
+                                    Cv[j][1]);
+                    Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(avec2, avec2),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
+                                    Cv[j][2]);
+                    Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(avec3, avec3),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
+                                    Cv[j][3]);
+                }
+            }
+
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < 4; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+
+    // Templated functions for gemm of dimensions Mx4
+    template <int RM>
+    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
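+        // Mirror image of gemm4xN: the tile is RM rows by 4 columns, so the
+        // four packed fp16 deltas now come from B rather than A.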
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / 4;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * 4;
+            __m256 Cv[4][RM] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
+                // Convert the delta values of the four blocks to float values
+                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
+                __m256i bvec0 = load(B + ldb * (jj + 0) + l);
+                __m256i bvec1 = load(B + ldb * (jj + 1) + l);
+                __m256i bvec2 = load(B + ldb * (jj + 2) + l);
+                __m256i bvec3 = load(B + ldb * (jj + 3) + l);
+                for (int64_t i = 0; i < RM; ++i) {
+                    __m128 da = _mm_set1_ps(unhalf(A[lda * (ii + i) + l].d));
+                    // Compute the products of the four B deltas with the A delta
+                    // and replicate them across the 256-bit lane
+                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
+                    // Compute the dot products and scale each one by the matching delta product
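+                    // _mm256_sign_epi8(a, a) yields |a| (unsigned), and applying
+                    // the sign of a to b preserves each product's sign for
+                    // updot's unsigned-by-signed multiply-add.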
+                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
+                                    Cv[0][i]);
+                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
+                                    Cv[1][i]);
+                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
+                                    Cv[2][i]);
+                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
+                                    Cv[3][i]);
+                }
+            }
+            for (int64_t j = 0; j < 4; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+#endif
+
     template <int RM, int RN>
     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t ytiles = (m - m0) / RM;
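
Note on the F16C trick used above (a standalone sketch, not part of the patch; toy delta values chosen for illustration): the four per-block fp16 scales are packed into one 64-bit word, widened to fp32 with a single _mm_cvtph_ps, multiplied by the opposing matrix's delta, replicated across both 128-bit halves with _mm256_permute2f128_ps, and then broadcast one row at a time with _mm256_shuffle_ps. Compile with -mavx2 -mf16c:

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // Four fp16 deltas as raw bits: 1.0, 2.0, 3.0, 4.0.
    uint16_t d0 = 0x3C00, d1 = 0x4000, d2 = 0x4200, d3 = 0x4400;
    uint64_t packed = ((uint64_t)d3 << 48) | ((uint64_t)d2 << 32) |
                      ((uint64_t)d1 << 16) | (uint64_t)d0;
    // One F16C instruction widens all four halves to fp32.
    __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, packed));
    // Scale by the opposing matrix's delta, then replicate the four products
    // into both 128-bit halves of a 256-bit register.
    __m128 db = _mm_set1_ps(0.5f);
    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
    // Immediate 85 == 0b01010101 selects element 1 in every position,
    // broadcasting row 1's delta product (2.0 * 0.5 = 1.0) to all 8 lanes.
    __m256 row1 = _mm256_shuffle_ps(dvec, dvec, 85);
    float out[8];
    _mm256_storeu_ps(out, row1);
    printf("%f %f\n", out[0], out[7]); // 1.000000 1.000000
    return 0;
}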