@@ -606,17 +606,29 @@ class tinyBLAS_Q0_AVX {
606606        case  0x44 :
607607            mc = 4 ;
608608            nc = 4 ;
609+ #if  defined(__AVX2__) && defined(__F16C__)
610+             gemm4xN<4 >(m0, m, n0, n);
611+ #else 
609612            gemm<4 , 4 >(m0, m, n0, n);
613+ #endif 
610614            break ;
611615        case  0x43 :
612616            mc = 4 ;
613617            nc = 3 ;
618+ #if  defined(__AVX2__) && defined(__F16C__)
619+             gemm4xN<3 >(m0, m, n0, n);
620+ #else 
614621            gemm<4 , 3 >(m0, m, n0, n);
622+ #endif 
615623            break ;
616624        case  0x34 :
617625            mc = 3 ;
618626            nc = 4 ;
627+ #if  defined(__AVX2__) && defined(__F16C__)
628+             gemmMx4<3 >(m0, m, n0, n);
629+ #else 
619630            gemm<3 , 4 >(m0, m, n0, n);
631+ #endif 
620632            break ;
621633        case  0x33 :
622634            mc = 3 ;
@@ -626,26 +638,42 @@ class tinyBLAS_Q0_AVX {
626638        case  0x42 :
627639            mc = 4 ;
628640            nc = 2 ;
641+ #if  defined(__AVX2__) && defined(__F16C__)
642+             gemm4xN<2 >(m0, m, n0, n);
643+ #else 
629644            gemm<4 , 2 >(m0, m, n0, n);
645+ #endif 
630646            break ;
631647        case  0x24 :
632648            mc = 2 ;
633649            nc = 4 ;
650+ #if  defined(__AVX2__) && defined(__F16C__)
651+             gemmMx4<2 >(m0, m, n0, n);
652+ #else 
634653            gemm<2 , 4 >(m0, m, n0, n);
654+ #endif 
635655            break ;
636656#else 
637657        case  0x44 :
638658        case  0x43 :
639659        case  0x42 :
640660            mc = 4 ;
641661            nc = 2 ;
662+ #if  defined(__AVX2__) && defined(__F16C__)
663+             gemm4xN<2 >(m0, m, n0, n);
664+ #else 
642665            gemm<4 , 2 >(m0, m, n0, n);
666+ #endif 
643667            break ;
644668        case  0x34 :
645669        case  0x24 :
646670            mc = 2 ;
647671            nc = 4 ;
672+ #if  defined(__AVX2__) && defined(__F16C__)
673+             gemmMx4<2 >(m0, m, n0, n);
674+ #else 
648675            gemm<2 , 4 >(m0, m, n0, n);
676+ #endif 
649677            break ;
650678        case  0x33 :
651679#endif 
@@ -662,7 +690,11 @@ class tinyBLAS_Q0_AVX {
662690        case  0x41 :
663691            mc = 4 ;
664692            nc = 1 ;
693+ #if  defined(__AVX2__) && defined(__F16C__)
694+             gemm4xN<1 >(m0, m, n0, n);
695+ #else 
665696            gemm<4 , 1 >(m0, m, n0, n);
697+ #endif 
666698            break ;
667699        case  0x22 :
668700            mc = 2 ;
@@ -672,7 +704,11 @@ class tinyBLAS_Q0_AVX {
672704        case  0x14 :
673705            mc = 1 ;
674706            nc = 4 ;
707+ #if  defined(__AVX2__) && defined(__F16C__)
708+             gemmMx4<1 >(m0, m, n0, n);
709+ #else 
675710            gemm<1 , 4 >(m0, m, n0, n);
711+ #endif 
676712            break ;
677713        case  0x31 :
678714            mc = 3 ;
@@ -708,6 +744,119 @@ class tinyBLAS_Q0_AVX {
708744        mnpack (m0, m, np, n);
709745    }
710746
747+ #if  defined(__AVX2__) && defined(__F16C__)
748+ //  Templated functions for gemm of dimensions 4xN
749+     template  <int  RN>
750+     NOINLINE void  gemm4xN (int64_t  m0, int64_t  m, int64_t  n0, int64_t  n) {
751+         int64_t  ytiles = (m - m0) / 4 ;
752+         int64_t  xtiles = (n - n0) / RN;
753+         int64_t  tiles = xtiles * ytiles;
754+         int64_t  duty = (tiles + nth - 1 ) / nth;
755+         int64_t  start = duty * ith;
756+         int64_t  end = start + duty;
757+         if  (end > tiles)
758+             end = tiles;
759+         for  (int64_t  job = start; job < end; ++job) {
760+             int64_t  ii = m0 + job / xtiles * 4 ;
761+             int64_t  jj = n0 + job % xtiles * RN;
762+             __m256 Cv[RN][4 ] = {};
763+             for  (int64_t  l = 0 ; l < k; ++l) {
764+                 uint64_t  a_delta = ((uint64_t )A[lda * (ii + 3 ) + l].d  << 48 ) | ((uint64_t )A[lda * (ii + 2 ) + l].d  << 32 ) | ((uint64_t )A[lda * (ii + 1 ) + l].d  << 16 ) | (A[lda * (ii + 0 ) + l].d );
765+                 //  Convert delta values for four blocks to float values
766+                 __m128 da = _mm_cvtph_ps (_mm_set_epi64x (0 , a_delta));
767+                 __m256i avec0 = load (A + lda * (ii + 0 ) + l);
768+                 __m256i avec1 = load (A + lda * (ii + 1 ) + l);
769+                 __m256i avec2 = load (A + lda * (ii + 2 ) + l);
770+                 __m256i avec3 = load (A + lda * (ii + 3 ) + l);
771+                 for  (int64_t  j = 0 ; j < RN; ++j) {
772+                         __m128 db = _mm_set1_ps (unhalf (B[ldb * (jj + j) + l].d ));
773+                         //  Computation of product of delta values for four blocks and replicate it across 256 bit lane
774+                         __m256 dvec =  _mm256_castps128_ps256 (_mm_mul_ps (da, db));
775+                         dvec = _mm256_permute2f128_ps (dvec ,dvec, 0 );
776+                         //  Computation of dot product and multiplication with appropriate delta value products
777+                         Cv[j][0 ] = madd (_mm256_shuffle_ps (dvec, dvec, 0 ),
778+                                     updot (_mm256_sign_epi8 (avec0, avec0),
779+                                           _mm256_sign_epi8 (load (B + ldb * (jj + j) + l), avec0)),
780+                                     Cv[j][0 ]);
781+                         Cv[j][1 ] = madd (_mm256_shuffle_ps (dvec, dvec, 85 ),
782+                                     updot (_mm256_sign_epi8 (avec1, avec1),
783+                                             _mm256_sign_epi8 (load (B + ldb * (jj + j) + l), avec1)),
784+                                     Cv[j][1 ]);
785+                         Cv[j][2 ] = madd (_mm256_shuffle_ps (dvec, dvec, 170 ),
786+                                     updot (_mm256_sign_epi8 (avec2, avec2),
787+                                             _mm256_sign_epi8 (load (B + ldb * (jj + j) + l), avec2)),
788+                                     Cv[j][2 ]);
789+                         Cv[j][3 ] = madd (_mm256_shuffle_ps (dvec, dvec, 255 ),
790+                                     updot (_mm256_sign_epi8 (avec3, avec3),
791+                                             _mm256_sign_epi8 (load (B + ldb * (jj + j) + l), avec3)),
792+                                     Cv[j][3 ]);
793+                 }
794+             }
795+ 
796+             for  (int64_t  j = 0 ; j < RN; ++j)
797+                 for  (int64_t  i = 0 ; i < 4 ; ++i)
798+                     C[ldc * (jj + j) + (ii + i)] = hsum (Cv[j][i]);
799+         }
800+     }
801+ 
802+     //  Templated functions for gemm of dimensions Mx4
803+     template  <int  RM>
804+     NOINLINE void  gemmMx4 (int64_t  m0, int64_t  m, int64_t  n0, int64_t  n) {
805+         int64_t  ytiles = (m - m0) / RM;
806+         int64_t  xtiles = (n - n0) / 4 ;
807+         int64_t  tiles = xtiles * ytiles;
808+         int64_t  duty = (tiles + nth - 1 ) / nth;
809+         int64_t  start = duty * ith;
810+         int64_t  end = start + duty;
811+         if  (end > tiles)
812+             end = tiles;
813+         for  (int64_t  job = start; job < end; ++job) {
814+             int64_t  ii = m0 + job / xtiles * RM;
815+             int64_t  jj = n0 + job % xtiles * 4 ;
816+             __m256 Cv[4 ][RM] = {};
817+             for  (int64_t  l = 0 ; l < k; ++l) {
818+                 uint64_t  b_delta = ((uint64_t )B[ldb * (jj + 3 ) + l].d  << 48 ) | ((uint64_t )B[ldb * (jj + 2 ) + l].d  << 32 ) | ((uint64_t )B[ldb * (jj + 1 ) + l].d  << 16 ) | (B[ldb * (jj + 0 ) + l].d );
819+                 //  Convert delta values for four blocks to float values
820+                 __m128 db = _mm_cvtph_ps (_mm_set_epi64x (0 , b_delta));
821+                 __m256i bvec0 = load (B + ldb * (jj + 0 ) + l);
822+                 __m256i bvec1 = load (B + ldb * (jj + 1 ) + l);
823+                 __m256i bvec2 = load (B + ldb * (jj + 2 ) + l);
824+                 __m256i bvec3 = load (B + ldb * (jj + 3 ) + l);
825+                 for  (int64_t  i = 0 ; i < RM; ++i) {
826+                     __m128 da = _mm_set1_ps (unhalf ((A[lda * (ii + i) + l].d )));
827+                     //  Computation of product of delta values for four blocks and replicate it across 256 bit lane
828+                     __m256 dvec =  _mm256_castps128_ps256 (_mm_mul_ps (da, db));
829+                     dvec = _mm256_permute2f128_ps (dvec ,dvec, 0 );
830+                     //  Computation of dot product and multiplication with appropriate delta value products
831+                     Cv[0 ][i] = madd (_mm256_shuffle_ps (dvec, dvec, 0 ),
832+                                     updot (_mm256_sign_epi8 (load (A + lda * (ii + i) + l),
833+                                                             load (A + lda * (ii + i) + l)),
834+                                             _mm256_sign_epi8 (bvec0, load (A + lda * (ii + i) + l))),
835+                                     Cv[0 ][i]);
836+                     Cv[1 ][i] = madd (_mm256_shuffle_ps (dvec, dvec, 85 ),
837+                                     updot (_mm256_sign_epi8 (load (A + lda * (ii + i) + l),
838+                                                             load (A + lda * (ii + i) + l)),
839+                                             _mm256_sign_epi8 (bvec1, load (A + lda * (ii + i) + l))),
840+                                     Cv[1 ][i]);
841+                     Cv[2 ][i] = madd (_mm256_shuffle_ps (dvec, dvec, 170 ),
842+                                     updot (_mm256_sign_epi8 (load (A + lda * (ii + i) + l),
843+                                                             load (A + lda * (ii + i) + l)),
844+                                             _mm256_sign_epi8 (bvec2, load (A + lda * (ii + i) + l))),
845+                                     Cv[2 ][i]);
846+                     Cv[3 ][i] = madd (_mm256_shuffle_ps (dvec, dvec, 255 ),
847+                                     updot (_mm256_sign_epi8 (load (A + lda * (ii + i) + l),
848+                                                             load (A + lda * (ii + i) + l)),
849+                                             _mm256_sign_epi8 (bvec3, load (A + lda * (ii + i) + l))),
850+                                     Cv[3 ][i]);
851+                 }
852+             }
853+             for  (int64_t  j = 0 ; j < 4 ; ++j)
854+                 for  (int64_t  i = 0 ; i < RM; ++i)
855+                     C[ldc * (jj + j) + (ii + i)] = hsum (Cv[j][i]);
856+         }
857+     }
858+ #endif 
859+ 
711860    template  <int  RM, int  RN>
712861    NOINLINE void  gemm (int64_t  m0, int64_t  m, int64_t  n0, int64_t  n) {
713862        int64_t  ytiles = (m - m0) / RM;
0 commit comments