@@ -87,13 +87,12 @@ struct EvenSignHelper {
     const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);
     const __m256i mask   = _mm256_set1_epi32(127);
     const __m256i mone   = _mm256_set1_epi32(1);
-#else
+#endif
     inline void sign_value(uint32_t aux32, __m256i& value) const {
         auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
                                        keven_signs[(aux32 >>  7) & 127], keven_signs[(aux32 >>  0) & 127]);
         value = _mm256_sign_epi8(value, signs);
     }
-#endif
 };
 
 struct SignHelper {
@@ -165,6 +193,16 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
         auto sc16 = load_scales(i);
         scales[0] = MM256_SET_M128I(sc16, sc16);
     }
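+    // f32-scale variant of new_block() for the q8_2_x4 path: broadcasts the
+    // low/high halves of the converted scales and precomputes -minv*scales
+    // for the min-correction term.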
+    inline void new_block_f(int i, __m256 * scales) {
+        auto sc16  = load_scales(i);
+        auto scf   = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(sc16)));
+        auto scf_l = _mm256_castps256_ps128(scf);
+        auto scf_h = _mm256_extractf128_ps(scf, 1);
+        scales[0] = _mm256_set_m128(scf_l, scf_l);
+        scales[1] = _mm256_set_m128(scf_h, scf_h);
+        scales[2] = _mm256_mul_ps(scf, _mm256_set1_ps(-minv));
+    }
+
     inline float new_block(int i, __m256i * scales, __m256i& mins) {
         auto sc16 = load_scales(i);
         mins = scb.shuffle(sc16);
@@ -729,6 +767,130 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
     }
 }
 
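+// Matmul of QK_K i-quants against a q8_2_x4 right-hand side (f32-scale path).
+// Requires the dequantizer to provide new_block_f(); otherwise it mirrors
+// mul_mat_qX_K_q8_K_IQ_N above.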
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    static_assert(Dequantizer::num_blocks == 8);
+    const int nb = n / QK_K;
+    Q8<nrc_y, block_q8_2_x4> q8(info);
+    Dequantizer deq(vx, bx);
+    __m256  scales[3];
+    __m256  accd[nrc_y];
+    __m256i sumi[4];
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+        deq.new_row(ix);
+
+        for (int i = 0; i < nb; ++i) {
+
+            deq.new_block_f(i, scales);
+            for (int iy = 0; iy < nrc_y; ++iy) {
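+                // d[4..7] of a block_q8_2_x4 hold the per-block sums; the 16-bit
+                // left shift promotes their bf16 bits to f32.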
+                auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
+                auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
+                auto my  = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
+                accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
+            }
+
+            for (int j = 0; j < QK_K/128; ++j) {
+                deq.prepare(i, j);
+                auto& values = deq.bits.values;
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto qs = q8.y[iy][2*i+j].qs;
+#ifdef HAVE_FANCY_SIMD
+                    sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[0], _mm256_loadu_si256((const __m256i *)qs+0));
+                    sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[1], _mm256_loadu_si256((const __m256i *)qs+1));
+                    sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[2], _mm256_loadu_si256((const __m256i *)qs+2));
+                    sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[3], _mm256_loadu_si256((const __m256i *)qs+3));
+#else
+                    sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[0], _mm256_loadu_si256((const __m256i *)qs+0)));
+                    sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[1], _mm256_loadu_si256((const __m256i *)qs+1)));
+                    sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[2], _mm256_loadu_si256((const __m256i *)qs+2)));
+                    sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[3], _mm256_loadu_si256((const __m256i *)qs+3)));
+#endif
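+                    // Two unpack/add rounds fold the four per-sub-block accumulators
+                    // into sumi[0], which then holds the eight 32-bit dot products.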
+                    sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
+                    sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
+                    sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
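+                    // d[0..3] are the four bf16 block scales; widen and shift to f32.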
+                    auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
+                    auto dy = _mm256_set_m128(d4, d4);
+                    accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
+                }
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, hsum_float_8(accd[iy]));
+        }
+    }
+}
+
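+// Specialization for iq2_xxs against q8_2_x4: decodes scales and grid values
+// inline instead of going through a Dequantizer.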
+template <int nrc_y>
+static void mul_mat_iq2_xxs_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    const int nb = n / QK_K;
+    __m256  scales[2];
+    __m256  accd[nrc_y];
+    __m256i sumi[4];
+    __m256i xv[4];
+    EvenSignHelper esh;
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+        const block_iq2_xxs * x = (const block_iq2_xxs *)((const char *)vx + ix*bx);
+
+        for (int i = 0; i < nb; ++i) {
+            const float d = GGML_FP16_TO_FP32(x[i].d)*0.125f;
+            const uint16_t * a16 = x[i].qs;
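+            // The 4-bit block scales sit in the top nibble of every 4th uint16
+            // of qs; the effective scale is 2*s + 1.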
+            auto sc16 = _mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]);
+            sc16 = _mm_or_si128(_mm_slli_epi16(_mm_srli_epi16(sc16, 12), 1), _mm_set1_epi16(1));
+            auto sc32 = _mm256_cvtepi16_epi32(sc16);
+            auto all_scales = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sc32));
+            auto all_mins   = _mm256_mul_ps(all_scales, _mm256_set1_ps(-43.f));
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto y = (const block_q8_2_x4 *)info.src1_row(iy);
+                auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(y[2*i+0].d + 4)));
+                auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(y[2*i+1].d + 4)));
+                auto my  = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
+                accd[iy] = _mm256_fmadd_ps(all_mins, my, accd[iy]);
+            }
+            auto scales_l = _mm256_castps256_ps128(all_scales);
+            auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+            scales[0] = _mm256_set_m128(scales_l, scales_l);
+            scales[1] = _mm256_set_m128(scales_h, scales_h);
+
+            for (int j = 0; j < QK_K/128; ++j) {
+                const uint8_t * a8 = (const uint8_t *)(a16 + 16*j);
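+                // Each byte of a8 indexes an 8-byte row of iq2xxs_grid; apply the
+                // packed signs, then bias by +43 so the values are unsigned for
+                // the u8*s8 dot products below.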
+                for (int k = 0; k < 4; ++k) {
+                    auto a8k = a8 + 8*k;
+                    xv[k] = _mm256_set_epi64x(iq2xxs_grid[a8k[3]], iq2xxs_grid[a8k[2]], iq2xxs_grid[a8k[1]], iq2xxs_grid[a8k[0]]);
+                    uint32_t aux32; std::memcpy(&aux32, a8k+4, sizeof(uint32_t));
+                    esh.sign_value(aux32, xv[k]);
+                    xv[k] = _mm256_add_epi8(xv[k], _mm256_set1_epi8(43));
+                }
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = (const block_q8_2_x4 *)info.src1_row(iy);
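+                    // Note: like mul_mat_qX_K_q8_2_IQ_N above, guard the VNNI
+                    // instructions so plain AVX2 builds fall back to maddubs.
+#ifdef HAVE_FANCY_SIMD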
+                    sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[0], _mm256_loadu_si256((const __m256i *)y[2*i+j].qs+0));
+                    sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[1], _mm256_loadu_si256((const __m256i *)y[2*i+j].qs+1));
+                    sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[2], _mm256_loadu_si256((const __m256i *)y[2*i+j].qs+2));
+                    sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[3], _mm256_loadu_si256((const __m256i *)y[2*i+j].qs+3));
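+#else
+                    sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(xv[0], _mm256_loadu_si256((const __m256i *)y[2*i+j].qs+0)));
+                    sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(xv[1], _mm256_loadu_si256((const __m256i *)y[2*i+j].qs+1)));
+                    sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(xv[2], _mm256_loadu_si256((const __m256i *)y[2*i+j].qs+2)));
+                    sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(xv[3], _mm256_loadu_si256((const __m256i *)y[2*i+j].qs+3)));
+#endif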
+                    sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
+                    sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
+                    sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
+                    auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)y[2*i+j].d)), 16));
+                    auto dy = _mm256_set_m128(d4, d4);
+                    accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
+                }
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, hsum_float_8(accd[iy]));
+        }
+    }
+}
+
 template <typename Dequantizer, int nrc_y>
 static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n % QK_K == 0);
@@ -1560,6 +1722,55 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
     }
 }
 
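+// Repack 8 interleaved rows of iq2_xxs into q8_0_r8: dequantize each 32-value
+// block to int8 and store the per-block scales as f16.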
+void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_iq2_xxs * x8[8];
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    ggml_half dh[8];
+    uint16_t  all_ls[64];
+    EvenSignHelper esh;
+
+    uint32_t block[8];
+    uint32_t aux32[2];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xxs *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            // TODO: simdify
+            for (int k = 0; k < 8; ++k) {
+                dh[k] = x8[k][i].d;
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    std::memcpy(aux32, x8[k][i].qs + 4*ib32, 2*sizeof(uint32_t));
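+                    // aux32[1] carries the signs (3x7 bits) and the 4-bit scale
+                    // in the top nibble; the effective scale is 2*s + 1.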
+                    all_ls[8*ib32 + k] = (2*(aux32[1] >> 28) + 1);
+                    auto value = _mm256_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+                    esh.sign_value(aux32[1], value);
+                    _mm256_storeu_si256((__m256i *)block, value);
+                    auto qs = (uint32_t *)y[ib32].qs;
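+                    // Interleave across the 8 rows: row k contributes dword k of
+                    // every group of 8 dwords in the r8 layout.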
+                    for (int l = 0; l < 4; ++l) {
+                        qs[8*l + k +  0] = block[l + 0];
+                        qs[8*l + k + 32] = block[l + 4];
+                    }
+                }
+            }
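+            // Fold the global 1/8 factor of iq2_xxs into the per-row scales.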
+            auto vd = _mm256_mul_ps(_mm256_set1_ps(0.125f), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh)));
+            for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+                auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
+                auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
+                auto scales    = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
+                _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+            }
+            y += QK_K/32;
+        }
+    }
+}
+
 template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
     funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
     funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
@@ -1575,7 +1786,19 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
 
 bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
 
-    if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
+    if (ne00%QK_K != 0) return false;
+
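+    // IQ2_XXS now accepts a Q8_2_X4 r.h.s.; all other i-quants still require Q8_K.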
+    if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) {
+        if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ2XXS, kernels);
+            //IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_q8_2_IQ_N, kernels);
+            func16 = nullptr;
+            return true;
+        }
+        return false;
+    }
+
+    if (ggml_type(typeB) != GGML_TYPE_Q8_K) {
         return false;
     }
 
@@ -1629,6 +1852,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
 
 }
 
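+// Entry point for repacking i-quants to q8_0_r8; currently only iq2_xxs is handled.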
+bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    if (n%QK_K != 0 || nrc_x%8 != 0) return false;
+    switch (ggml_type(type)) {
+        case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
+        default: return false;
+    }
+    return true;
+}
+
 #else
 // --------------------------------------- __aarch64__ ---------------------------------------------
 