@@ -173,6 +173,132 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
173173 return scale ;
174174}
175175
176+ static float make_qxg_quants (int n , int nmax , const float * restrict x , int8_t * restrict L , int rmse_type ,
177+ const float * restrict qw ) {
178+ float max = 0 ;
179+ float amax = 0 ;
180+ for (int i = 0 ; i < n ; ++ i ) {
181+ float ax = fabsf (x [i ]);
182+ if (ax > amax ) { amax = ax ; max = x [i ]; }
183+ }
184+ if (amax < GROUP_MAX_EPS ) { // all zero
185+ for (int i = 0 ; i < n ; ++ i ) {
186+ L [i ] = 0 ;
187+ }
188+ return 0.f ;
189+ }
190+ float iscale = - nmax / max ;
191+ if (rmse_type == 0 ) {
192+ for (int i = 0 ; i < n ; ++ i ) {
193+ int l = nearest_int (iscale * x [i ]);
194+ L [i ] = nmax + MAX (- nmax , MIN (nmax - 1 , l ));
195+ }
196+ return 1 /iscale ;
197+ }
198+ bool return_early = false;
199+ if (rmse_type < 0 ) {
200+ rmse_type = - rmse_type ;
201+ return_early = true;
202+ }
203+ float sumlx = 0 ;
204+ float suml2 = 0 ;
205+ #ifdef HAVE_BUGGY_APPLE_LINKER
206+ // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
207+ for (volatile int i = 0 ; i < n ; ++ i ) {
208+ #else
209+ for (int i = 0 ; i < n ; ++ i ) {
210+ #endif
211+ int l = nearest_int (iscale * x [i ]);
212+ l = MAX (- nmax , MIN (nmax - 1 , l ));
213+ L [i ] = l + nmax ;
214+ float w = qw ? qw [i ] : rmse_type == 1 ? x [i ] * x [i ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [i ]) : sqrtf (fabsf (x [i ]));
215+ sumlx += w * x [i ]* l ;
216+ suml2 += w * l * l ;
217+ }
218+ float scale = suml2 ? sumlx /suml2 : 0.0f ;
219+ if (return_early ) return suml2 > 0 ? 0.5f * (scale + 1 /iscale ) : 1 /iscale ;
220+ float best = scale * sumlx ;
221+ float best_sumlx = sumlx , best_suml2 = suml2 ;
222+ for (int is = -9 ; is <= 9 ; ++ is ) {
223+ iscale = - (nmax + 0.1f * is ) / max ;
224+ sumlx = suml2 = 0 ;
225+ for (int i = 0 ; i < n ; ++ i ) {
226+ int l = nearest_int (iscale * x [i ]);
227+ l = MAX (- nmax , MIN (nmax - 1 , l ));
228+ float w = qw ? qw [i ] : rmse_type == 1 ? x [i ] * x [i ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [i ]) : sqrtf (fabsf (x [i ]));
229+ sumlx += w * x [i ]* l ;
230+ suml2 += w * l * l ;
231+ }
232+ if (suml2 > 0 && sumlx * sumlx > best * suml2 ) {
233+ for (int i = 0 ; i < n ; ++ i ) {
234+ int l = nearest_int (iscale * x [i ]);
235+ L [i ] = nmax + MAX (- nmax , MIN (nmax - 1 , l ));
236+ }
237+ scale = sumlx /suml2 ; best = scale * sumlx ;
238+ best_sumlx = sumlx ; best_suml2 = suml2 ;
239+ }
240+ iscale = (nmax - 1 + 0.1f * is ) / max ;
241+ sumlx = suml2 = 0 ;
242+ for (int i = 0 ; i < n ; ++ i ) {
243+ int l = nearest_int (iscale * x [i ]);
244+ l = MAX (- nmax , MIN (nmax - 1 , l ));
245+ float w = qw ? qw [i ] : rmse_type == 1 ? x [i ] * x [i ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [i ]) : sqrtf (fabsf (x [i ]));
246+ sumlx += w * x [i ]* l ;
247+ suml2 += w * l * l ;
248+ }
249+ if (suml2 > 0 && sumlx * sumlx > best * suml2 ) {
250+ for (int i = 0 ; i < n ; ++ i ) {
251+ int l = nearest_int (iscale * x [i ]);
252+ L [i ] = nmax + MAX (- nmax , MIN (nmax - 1 , l ));
253+ }
254+ scale = sumlx /suml2 ; best = scale * sumlx ;
255+ best_sumlx = sumlx ; best_suml2 = suml2 ;
256+ }
257+ }
258+
259+ sumlx = best_sumlx ; suml2 = best_suml2 ;
260+ for (int iter = 0 ; iter < n * (2 * nmax - 1 ); ++ iter ) {
261+ float abs_gmax = 0 , gmax = 0 ;
262+ int best_j = -1 ;
263+ for (int j = 0 ; j < n ; ++ j ) {
264+ float w = qw ? qw [j ] : rmse_type == 1 ? x [j ] * x [j ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [j ]) : sqrtf (fabsf (x [j ]));
265+ int l = L [j ] - nmax ;
266+ float g = scale * w * (x [j ] - scale * l );
267+ if ((g > 0 && l < nmax - 1 ) || (g < 0 && l > - nmax )) {
268+ float ag = fabsf (g );
269+ if (ag > abs_gmax ) {
270+ abs_gmax = ag ; gmax = g ; best_j = j ;
271+ }
272+ }
273+ }
274+ if (best_j < 0 ) break ;
275+
276+ float new_sumlx = sumlx , new_suml2 = suml2 ;
277+ float w = qw ? qw [best_j ] : rmse_type == 1 ? x [best_j ] * x [best_j ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [best_j ]) : sqrtf (fabsf (x [best_j ]));
278+ int l = L [best_j ] - nmax ;
279+ if (gmax > 0 ) {
280+ new_sumlx += w * x [best_j ];
281+ new_suml2 += w * (2 * l + 1 );
282+ l += 1 ;
283+ } else {
284+ new_sumlx -= w * x [best_j ];
285+ new_suml2 -= w * (2 * l - 1 );
286+ l -= 1 ;
287+ }
288+ if (new_suml2 > 0 && new_sumlx * new_sumlx > best * new_suml2 ) {
289+ sumlx = new_sumlx ; suml2 = new_suml2 ;
290+ scale = sumlx /suml2 ; best = scale * sumlx ;
291+ L [best_j ] = l + nmax ;
292+ GGML_ASSERT (L [best_j ] >= 0 && L [best_j ] <= 2 * nmax - 1 );
293+ }
294+ else {
295+ break ;
296+ }
297+
298+ }
299+ return scale ;
300+ }
301+
176302static float make_q3_quants (int n , int nmax , const float * restrict x , int8_t * restrict L , bool do_rmse ) {
177303 float max = 0 ;
178304 float amax = 0 ;
@@ -634,6 +760,194 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
634760 }
635761}
636762
// Lookup table mapping (int)x - values[0] (range 0..240) to a candidate
// slot in the 16-entry iq4nl value table. Entries < 16 are unambiguous
// bucket indices; an entry e >= 16 marks a boundary where slots e-16 and
// e-15 are both candidates and must be compared against x.
static const int8_t iq4nl_index[241] = {
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1, 17, 17,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 18,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3, 19,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4, 20,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
     5,  5, 21, 21,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6, 22,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7, 23, 23,  8,  8,  8,  8,
     8,  8,  8,  8,  8,  8, 24,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 25, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 26, 26,
    11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 27, 27, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 28, 13, 13, 13,
    13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 29, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
    14, 14, 14, 14, 30, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15
};

// Return the index (0..15) of the entry in `values` closest to x,
// using iq4nl_index to avoid a linear scan. Out-of-range inputs clamp
// to the first/last slot.
static inline int best_index_iq4nl(const int8_t * values, float x) {
    const int pos = (int)x - values[0];
    if (pos < 0) {
        return 0;
    }
    if (pos >= 241) {
        return 15;
    }
    const int hint = iq4nl_index[pos];
    if (hint < 16) {
        return hint;  // unambiguous bucket
    }
    // boundary case: pick the nearer of the two neighbouring values
    const int lo = hint - 16;
    return (x - values[lo] < values[lo + 1] - x) ? lo : lo + 1;
}
779+
780+ static void quantize_row_iq4_nl_g_impl (const int super_block_size , const int block_size , const float * restrict x ,
781+ ggml_fp16_t * dh , uint8_t * q4 , uint16_t * scales_h , uint8_t * scales_l ,
782+ float * scales , float * weight , uint8_t * L ,
783+ const int8_t * values ,
784+ const float * quant_weights ,
785+ const int ntry ) {
786+
787+ float sigma2 = 0 ;
788+ for (int j = 0 ; j < super_block_size ; ++ j ) sigma2 += x [j ]* x [j ];
789+ sigma2 *= 2.f /super_block_size ;
790+
791+ memset (q4 , 0 , super_block_size /2 );
792+ dh [0 ] = GGML_FP32_TO_FP16 (0.f );
793+
794+ float max_scale = 0 , amax_scale = 0 ;
795+ for (int ib = 0 ; ib < super_block_size /block_size ; ++ ib ) {
796+ const float * xb = x + ib * block_size ;
797+ uint8_t * Lb = L + ib * block_size ;
798+ if (quant_weights ) {
799+ const float * qw = quant_weights + ib * block_size ;
800+ for (int j = 0 ; j < block_size ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
801+ } else {
802+ for (int j = 0 ; j < block_size ; ++ j ) weight [j ] = xb [j ]* xb [j ];
803+ }
804+ float amax = 0 , max = 0 ;
805+ for (int j = 0 ; j < block_size ; ++ j ) {
806+ float ax = fabsf (xb [j ]);
807+ if (ax > amax ) {
808+ amax = ax ; max = xb [j ];
809+ }
810+ }
811+ if (amax < GROUP_MAX_EPS ) {
812+ scales [ib ] = 0 ;
813+ continue ;
814+ }
815+ float d = ntry > 0 ? - max /values [0 ] : max /values [0 ];
816+ float id = 1 /d ;
817+ float sumqx = 0 , sumq2 = 0 ;
818+ for (int j = 0 ; j < block_size ; ++ j ) {
819+ float al = id * xb [j ];
820+ int l = best_index_iq4nl (values , al );
821+ Lb [j ] = l ;
822+ float q = values [l ];
823+ float w = weight [j ];
824+ sumqx += w * q * xb [j ];
825+ sumq2 += w * q * q ;
826+ }
827+ d = sumqx /sumq2 ;
828+ float best = d * sumqx ;
829+ float best_sumqx = sumqx , best_sumq2 = sumq2 ;
830+ for (int itry = - ntry ; itry <= ntry ; ++ itry ) {
831+ id = (itry + values [0 ])/max ;
832+ sumqx = sumq2 = 0 ;
833+ for (int j = 0 ; j < block_size ; ++ j ) {
834+ float al = id * xb [j ];
835+ int l = best_index_iq4nl (values , al );
836+ float q = values [l ];
837+ float w = weight [j ];
838+ sumqx += w * q * xb [j ];
839+ sumq2 += w * q * q ;
840+ }
841+ if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
842+ d = sumqx /sumq2 ; best = d * sumqx ;
843+ best_sumqx = sumqx ; best_sumq2 = sumq2 ;
844+ for (int j = 0 ; j < block_size ; ++ j ) {
845+ float al = id * xb [j ];
846+ Lb [j ] = best_index_iq4nl (values , al );
847+ }
848+ }
849+ id = (itry + values [15 ])/max ;
850+ sumqx = sumq2 = 0 ;
851+ for (int j = 0 ; j < block_size ; ++ j ) {
852+ float al = id * xb [j ];
853+ int l = best_index_iq4nl (values , al );
854+ float q = values [l ];
855+ float w = weight [j ];
856+ sumqx += w * q * xb [j ];
857+ sumq2 += w * q * q ;
858+ }
859+ if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
860+ d = sumqx /sumq2 ; best = d * sumqx ;
861+ best_sumqx = sumqx ; best_sumq2 = sumq2 ;
862+ for (int j = 0 ; j < block_size ; ++ j ) {
863+ float al = id * xb [j ];
864+ Lb [j ] = best_index_iq4nl (values , al );
865+ }
866+ }
867+ }
868+ sumqx = best_sumqx ; sumq2 = best_sumq2 ;
869+ for (int iter = 0 ; iter < 32 * block_size ; ++ iter ) {
870+ float min_step = INFINITY ;
871+ int best_j = -1 ; int dir = 0 ;
872+ for (int j = 0 ; j < block_size ; ++ j ) {
873+ float w = weight [j ];
874+ float g = d * w * (xb [j ] - d * values [Lb [j ]]);
875+ if (g > 0 && Lb [j ] < 15 ) {
876+ float step = (values [Lb [j ]+ 1 ] - values [Lb [j ]])/g ;
877+ if (step < min_step ) {
878+ min_step = step ; best_j = j ; dir = 1 ;
879+ }
880+ }
881+ else if (g < 0 && Lb [j ] > 0 ) {
882+ float step = (values [Lb [j ]- 1 ] - values [Lb [j ]])/g ;
883+ if (step < min_step ) {
884+ min_step = step ; best_j = j ; dir = -1 ;
885+ }
886+ }
887+ }
888+ if (best_j < 0 ) break ;
889+
890+ float new_sumqx = sumqx , new_sumq2 = sumq2 ;
891+ float w = weight [best_j ];
892+ new_sumqx += w * xb [best_j ]* (values [Lb [best_j ]+ dir ] - values [Lb [best_j ]]);
893+ new_sumq2 += w * (values [Lb [best_j ]+ dir ]* values [Lb [best_j ]+ dir ] - values [Lb [best_j ]]* values [Lb [best_j ]]);
894+ if (new_sumq2 > 0 && new_sumqx * new_sumqx > best * new_sumq2 ) {
895+ sumqx = new_sumqx ; sumq2 = new_sumq2 ;
896+ d = sumqx /sumq2 ; best = d * sumqx ;
897+ Lb [best_j ] += dir ;
898+ }
899+ else {
900+ break ;
901+ }
902+ }
903+
904+ scales [ib ] = d ;
905+ float abs_d = fabsf (d );
906+ if (abs_d > amax_scale ) {
907+ amax_scale = abs_d ; max_scale = d ;
908+ }
909+ }
910+
911+ if (super_block_size /block_size > 1 ) {
912+ int nb = super_block_size /block_size ;
913+ memset (scales_h , 0 , ((nb + 7 )/8 )* sizeof (uint16_t ));
914+ float d = - max_scale /32 ;
915+ dh [0 ] = GGML_FP32_TO_FP16 (d );
916+ float id = d ? 1 /d : 0.f ;
917+ for (int ib = 0 ; ib < super_block_size /block_size ; ++ ib ) {
918+ int l = nearest_int (id * scales [ib ]);
919+ l = MAX (-32 , MIN (31 , l ));
920+ float dl = d * l ;
921+ float idl = dl ? 1 /dl : 0.f ;
922+ uint8_t * Lb = L + ib * block_size ;
923+ const float * xb = x + ib * block_size ;
924+ for (int j = 0 ; j < block_size ; ++ j ) {
925+ Lb [j ] = best_index_iq4nl (values , idl * xb [j ]);
926+ }
927+ l += 32 ;
928+ uint8_t l_l = l & 0xf ;
929+ uint8_t l_h = l >> 4 ;
930+ if (ib %2 == 0 ) scales_l [ib /2 ] = l_l ;
931+ else scales_l [ib /2 ] |= (l_l << 4 );
932+ scales_h [ib /8 ] |= (l_h << 2 * (ib %8 ));
933+ }
934+ } else {
935+ dh [0 ] = GGML_FP32_TO_FP16 (scales [0 ]);
936+ if (ntry > 0 ) {
937+ float id = scales [0 ] ? 1 /scales [0 ] : 0 ;
938+ for (int j = 0 ; j < super_block_size ; ++ j ) {
939+ L [j ] = best_index_iq4nl (values , id * x [j ]);
940+ }
941+ }
942+ }
943+
944+ for (int i = 0 ; i < super_block_size /32 ; ++ i ) {
945+ for (int j = 0 ; j < 16 ; ++ j ) {
946+ q4 [16 * i + j ] = L [32 * i + j ] | (L [32 * i + 16 + j ] << 4 );
947+ }
948+ }
949+ }
950+
637951// ---- Custom experiments ----
638952
639953struct fraction {
@@ -2636,6 +2950,16 @@ void anyrize_qx(const float * x, const float * w, float * v, int ne0, int ne1, i
26362950 }
26372951}
26382952
// Round-trip quantization of an ne1 x ne0 matrix: quantize each row with
// make_qxg_quants (rmse_type 1, i.e. x^2 weighting) and write the
// dequantized values into v. w, when non-NULL, supplies per-element
// importance weights with the same layout as x.
void anyrize_qxg(const float * x, const float * w, float * v, int ne0, int ne1, int nmax) {
    int8_t L[ne0];  // biased quantized levels for the current row
    for (int row = 0; row < ne1; ++row) {
        const float * xrow = x + row*ne0;
        const float * wrow = w ? w + row*ne0 : NULL;
        float * vrow = v + row*ne0;
        const float scale = make_qxg_quants(ne0, nmax, xrow, L, 1, wrow);
        for (int col = 0; col < ne0; ++col) {
            vrow[col] = scale * (L[col] - nmax);
        }
    }
}
2962+
26392963void anyrize_qkxs (const float * x , const float * w , float * v , int ne0 , int ne1 , int nmin , int nmax , bool signed_scale ) {
26402964 struct fraction Faux [ne0 * MAX (abs (nmin ), abs (nmax ))];
26412965 int8_t L [ne0 ];
@@ -2832,6 +3156,23 @@ void anyrize_iq4nl(const float * x, const float * w, float * v, int ne0, int ne1
28323156 }
28333157}
28343158
3159+ void anyrize_iq4nl_g (const float * x , const float * w , float * v , int ne0 , int ne1 ) {
3160+ uint8_t L [ne0 ];
3161+ uint8_t Laux [ne0 ];
3162+ ggml_fp16_t unused_dh ;
3163+ uint8_t unused_q4 [ne0 ];
3164+ uint16_t unused_h ;
3165+ uint8_t * unused_l = NULL ;
3166+ float weight [ne0 ];
3167+ for (int i = 0 ; i < ne1 ; ++ i ) {
3168+ float scale = 0.0f ;
3169+ quantize_row_iq4_nl_g_impl (ne0 , ne0 , x + i * ne0 , & unused_dh , unused_q4 , & unused_h , unused_l , & scale , weight , L , kvalues_iq4nl , w ? w + i * ne0 : NULL , 7 );
3170+ for (int j = 0 ; j < ne0 ; ++ j ) {
3171+ v [i * ne0 + j ] = kvalues_iq4nl [L [j ]] * scale ;
3172+ }
3173+ }
3174+ }
3175+
28353176void anyrize_qkxs_iq4nl (const float * x , const float * w , float * v , int ne0 , int ne1 ) {
28363177 uint8_t L [ne0 ];
28373178 uint8_t Laux [ne0 ];
0 commit comments