@@ -564,10 +564,8 @@ static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8
564564 float scale = suml2 ? sumlx /suml2 : 0.0f ;
565565 if (return_early ) return suml2 > 0 ? 0.5f * (scale + 1 /iscale ) : 1 /iscale ;
566566 float best = scale * sumlx ;
567+ float best_sumlx = sumlx , best_suml2 = suml2 ;
567568 for (int is = -9 ; is <= 9 ; ++ is ) {
568- if (is == 0 ) {
569- continue ;
570- }
571569 iscale = - (nmax + 0.1f * is ) / max ;
572570 sumlx = suml2 = 0 ;
573571 for (int i = 0 ; i < n ; ++ i ) {
@@ -583,7 +581,66 @@ static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8
583581 L [i ] = nmax + MAX (- nmax , MIN (nmax - 1 , l ));
584582 }
585583 scale = sumlx /suml2 ; best = scale * sumlx ;
584+ best_sumlx = sumlx ; best_suml2 = suml2 ;
585+ }
586+ iscale = (nmax - 1 + 0.1f * is ) / max ;
587+ sumlx = suml2 = 0 ;
588+ for (int i = 0 ; i < n ; ++ i ) {
589+ int l = nearest_int (iscale * x [i ]);
590+ l = MAX (- nmax , MIN (nmax - 1 , l ));
591+ float w = qw ? qw [i ] : rmse_type == 1 ? x [i ] * x [i ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [i ]) : sqrtf (fabsf (x [i ]));
592+ sumlx += w * x [i ]* l ;
593+ suml2 += w * l * l ;
586594 }
595+ if (suml2 > 0 && sumlx * sumlx > best * suml2 ) {
596+ for (int i = 0 ; i < n ; ++ i ) {
597+ int l = nearest_int (iscale * x [i ]);
598+ L [i ] = nmax + MAX (- nmax , MIN (nmax - 1 , l ));
599+ }
600+ scale = sumlx /suml2 ; best = scale * sumlx ;
601+ best_sumlx = sumlx ; best_suml2 = suml2 ;
602+ }
603+ }
604+
605+ sumlx = best_sumlx ; suml2 = best_suml2 ;
606+ for (int iter = 0 ; iter < n * (2 * nmax - 1 ); ++ iter ) {
607+ float abs_gmax = 0 , gmax = 0 ;
608+ int best_j = -1 ;
609+ for (int j = 0 ; j < n ; ++ j ) {
610+ float w = qw ? qw [j ] : rmse_type == 1 ? x [j ] * x [j ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [j ]) : sqrtf (fabsf (x [j ]));
611+ int l = L [j ] - nmax ;
612+ float g = scale * w * (x [j ] - scale * l );
613+ if ((g > 0 && l < nmax - 1 ) || (g < 0 && l > - nmax )) {
614+ float ag = fabsf (g );
615+ if (ag > abs_gmax ) {
616+ abs_gmax = ag ; gmax = g ; best_j = j ;
617+ }
618+ }
619+ }
620+ if (best_j < 0 ) break ;
621+
622+ float new_sumlx = sumlx , new_suml2 = suml2 ;
623+ float w = qw ? qw [best_j ] : rmse_type == 1 ? x [best_j ] * x [best_j ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [best_j ]) : sqrtf (fabsf (x [best_j ]));
624+ int l = L [best_j ] - nmax ;
625+ if (gmax > 0 ) {
626+ new_sumlx += w * x [best_j ];
627+ new_suml2 += w * (2 * l + 1 );
628+ l += 1 ;
629+ } else {
630+ new_sumlx -= w * x [best_j ];
631+ new_suml2 -= w * (2 * l - 1 );
632+ l -= 1 ;
633+ }
634+ if (new_suml2 > 0 && new_sumlx * new_sumlx > best * new_suml2 ) {
635+ sumlx = new_sumlx ; suml2 = new_suml2 ;
636+ scale = sumlx /suml2 ; best = scale * sumlx ;
637+ L [best_j ] = l + nmax ;
638+ GGML_ASSERT (L [best_j ] >= 0 && L [best_j ] <= 2 * nmax - 1 );
639+ }
640+ else {
641+ break ;
642+ }
643+
587644 }
588645 return scale ;
589646}
@@ -889,8 +946,9 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
889946 float rmin , float rdelta , int nstep , bool use_mad ) {
890947 float min = x [0 ];
891948 float max = x [0 ];
892- float sum_w = weights ? weights [0 ] : x [0 ]* x [0 ];
893- float sum_x = sum_w * x [0 ];
949+ double sum_w = weights ? (double )weights [0 ] : (double )(x [0 ]* x [0 ]);
950+ double sum_x = sum_w * (double )x [0 ];
951+ double sum_x2 = sum_w * (double )x [0 ] * (double )x [0 ];
894952#ifdef HAVE_BUGGY_APPLE_LINKER
895953 // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
896954 for (volatile int i = 1 ; i < n ; ++ i ) {
@@ -900,8 +958,9 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
900958 if (x [i ] < min ) min = x [i ];
901959 if (x [i ] > max ) max = x [i ];
902960 float w = weights ? weights [i ] : x [i ]* x [i ];
903- sum_w += w ;
904- sum_x += w * x [i ];
961+ sum_w += (double )w ;
962+ sum_x += (double )w * (double )x [i ];
963+ sum_x2 += (double )w * (double )x [i ] * (double )x [i ];
905964 }
906965 if (min > 0 ) {
907966 min = 0 ;
@@ -913,13 +972,13 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
913972 }
914973 float iscale = nmax /(max - min );
915974 float scale = 1 /iscale ;
916- float best_mad = 0 ;
975+ double best_mad = 0 ;
917976 for (int i = 0 ; i < n ; ++ i ) {
918977 int l = nearest_int (iscale * (x [i ] - min ));
919978 L [i ] = MAX (0 , MIN (nmax , l ));
920- float diff = scale * L [i ] + min - x [i ];
921- diff = use_mad ? fabsf (diff ) : diff * diff ;
922- float w = weights ? weights [i ] : x [i ]* x [i ];
979+ double diff = ( double ) scale * L [i ] + ( double ) min - ( double ) x [i ];
980+ diff = use_mad ? fabs (diff ) : diff * diff ;
981+ double w = weights ? ( double ) weights [i ] : ( double )( x [i ]* x [i ]) ;
923982 best_mad += w * diff ;
924983 }
925984 if (nstep < 1 ) {
@@ -928,30 +987,35 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
928987 }
929988 for (int is = 0 ; is <= nstep ; ++ is ) {
930989 iscale = (rmin + rdelta * is + nmax )/(max - min );
931- float sum_l = 0 , sum_l2 = 0 , sum_xl = 0 ;
990+ double sum_l = 0 , sum_l2 = 0 , sum_xl = 0 ;
932991 for (int i = 0 ; i < n ; ++ i ) {
933992 int l = nearest_int (iscale * (x [i ] - min ));
934993 l = MAX (0 , MIN (nmax , l ));
935994 Laux [i ] = l ;
936995 float w = weights ? weights [i ] : x [i ]* x [i ];
937- sum_l += w * l ;
938- sum_l2 += w * l * l ;
939- sum_xl += w * l * x [i ];
996+ sum_l += ( double ) w * l ;
997+ sum_l2 += ( double ) w * l * l ;
998+ sum_xl += ( double ) w * l * ( double ) x [i ];
940999 }
941- float D = sum_w * sum_l2 - sum_l * sum_l ;
1000+ double D = sum_w * sum_l2 - sum_l * sum_l ;
9421001 if (D > 0 ) {
943- float this_scale = (sum_w * sum_xl - sum_x * sum_l )/D ;
944- float this_min = (sum_l2 * sum_x - sum_l * sum_xl )/D ;
1002+ double this_scale = (sum_w * sum_xl - sum_x * sum_l )/D ;
1003+ double this_min = (sum_l2 * sum_x - sum_l * sum_xl )/D ;
9451004 if (this_min > 0 ) {
9461005 this_min = 0 ;
9471006 this_scale = sum_xl / sum_l2 ;
9481007 }
949- float mad = 0 ;
950- for (int i = 0 ; i < n ; ++ i ) {
951- float diff = this_scale * Laux [i ] + this_min - x [i ];
952- diff = use_mad ? fabsf (diff ) : diff * diff ;
953- float w = weights ? weights [i ] : x [i ]* x [i ];
954- mad += w * diff ;
1008+ double mad = 0 ;
1009+ if (use_mad ) {
1010+ for (int i = 0 ; i < n ; ++ i ) {
1011+ double diff = (double )this_scale * Laux [i ] + (double )this_min - (double )x [i ];
1012+ diff = fabs (diff );
1013+ double w = weights ? (double )weights [i ] : (double )(x [i ]* x [i ]);
1014+ mad += w * diff ;
1015+ }
1016+ } else {
1017+ mad = sum_x2 - 2 * this_scale * sum_xl - 2 * this_min * sum_x + 2 * this_scale * this_min * sum_l
1018+ + this_scale * this_scale * sum_l2 + this_min * this_min * sum_w ;
9551019 }
9561020 if (mad < best_mad ) {
9571021 for (int i = 0 ; i < n ; ++ i ) {
@@ -963,6 +1027,57 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
9631027 }
9641028 }
9651029 }
1030+ if (use_mad ) {
1031+ * the_min = - min ;
1032+ return scale ;
1033+ }
1034+
1035+ double sum_l = 0 , sum_l2 = 0 , sum_xl = 0 ;
1036+ for (int i = 0 ; i < n ; ++ i ) {
1037+ int l = L [i ];
1038+ double w = weights ? (double )weights [i ] : (double )(x [i ]* x [i ]);
1039+ sum_l += w * l ;
1040+ sum_l2 += w * l * l ;
1041+ sum_xl += w * l * (double )x [i ];
1042+ }
1043+ double best = 2 * (double )scale * sum_xl + 2 * (double )min * sum_x - 2 * (double )scale * (double )min * sum_l
1044+ - (double )scale * (double )scale * sum_l2 - (double )min * (double )min * sum_w ;
1045+ int last_j = -1 , last_dir = 0 ;
1046+ for (int itry = 0 ; itry < nmax * n ; ++ itry ) {
1047+ float gmax = 0 ;
1048+ int best_j = -1 , dir = 0 ;
1049+ for (int j = 0 ; j < n ; ++ j ) {
1050+ float g = x [j ] - scale * L [j ] - min ;
1051+ if (g > 0 && L [j ] < nmax && g > gmax ) {
1052+ gmax = g ; best_j = j ; dir = 1 ;
1053+ }
1054+ else if (g < 0 && L [j ] > 0 && - g > gmax ) {
1055+ gmax = - g ; best_j = j ; dir = -1 ;
1056+ }
1057+ }
1058+ if (best_j < 0 || (best_j == last_j && dir == - last_dir )) break ;
1059+ double w = weights ? (double )weights [best_j ] : (double )(x [best_j ]* x [best_j ]);
1060+ sum_l += w * dir ;
1061+ sum_l2 += w * (2 * L [best_j ]* dir + 1 );
1062+ sum_xl += w * (double )x [best_j ]* dir ;
1063+ double D = (double )sum_w * sum_l2 - sum_l * sum_l ;
1064+ if (D <= 0 ) break ;
1065+ double this_scale = ((double )sum_w * sum_xl - (double )sum_x * sum_l )/D ;
1066+ double this_min = (sum_l2 * (double )sum_x - sum_l * sum_xl )/D ;
1067+ if (this_min > 0 ) {
1068+ this_min = 0 ;
1069+ this_scale = sum_xl / sum_l2 ;
1070+ }
1071+ if (this_scale < 0 ) break ;
1072+ double score = 2 * this_scale * sum_xl + 2 * this_min * (double )sum_x - 2 * this_scale * this_min * sum_l
1073+ - this_scale * this_scale * sum_l2 - this_min * this_min * (double )sum_w ;
1074+ if (score <= best ) break ;
1075+ best = score ;
1076+ scale = this_scale ;
1077+ min = this_min ;
1078+ L [best_j ] += dir ;
1079+ last_j = best_j ; last_dir = dir ;
1080+ }
9661081 * the_min = - min ;
9671082 return scale ;
9681083}
@@ -1044,7 +1159,7 @@ static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * G
10441159 GGML_ASSERT (quant_weights );
10451160 assert (k % QK_K == 0 );
10461161 const int nb = k / QK_K ;
1047- const bool requantize = true;
1162+ // const bool requantize = true;
10481163
10491164 uint8_t L [QK_K ];
10501165 uint8_t Laux [16 ];
@@ -1058,39 +1173,33 @@ static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * G
10581173 memset (sw , 0 , QK_K /16 * sizeof (float ));
10591174 float sumx2 = 0 ;
10601175 for (int j = 0 ; j < QK_K ; ++ j ) sumx2 += x [j ]* x [j ];
1061- float sigma2 = sumx2 /QK_K ;
1176+ float sigma2 = 0.75f * sumx2 /QK_K ;
10621177 for (int j = 0 ; j < QK_K /16 ; ++ j ) {
10631178 const float * GGML_RESTRICT qw = quant_weights + QK_K * i + 16 * j ;
10641179 for (int l = 0 ; l < 16 ; ++ l ) weight [l ] = qw [l ] * sqrtf (sigma2 + x [16 * j + l ]* x [16 * j + l ]);
10651180 for (int l = 0 ; l < QK_K /16 ; ++ l ) sw [j ] += weight [l ];
10661181 scales [j ] = make_qkx3_quants (16 , 3 , x + 16 * j , weight , L + 16 * j , & mins [j ], Laux , -0.9f , 0.05f , 36 , false);
10671182 }
10681183
1069- float dm , mm ;
1070- dm = make_qp_quants (QK_K /16 , 15 , scales , Ls , sw );
1071- mm = make_qp_quants (QK_K /16 , 15 , mins , Lm , sw );
1184+ float dm = make_qp_quants (QK_K /16 , 15 , scales , Ls , sw );
1185+ float mm = make_qp_quants (QK_K /16 , 15 , mins , Lm , sw );
10721186
10731187 y [i ].d = GGML_FP32_TO_FP16 (dm );
10741188 y [i ].dmin = GGML_FP32_TO_FP16 (mm );
1075- dm = GGML_FP16_TO_FP32 (y [i ].d );
1076- mm = GGML_FP16_TO_FP32 (y [i ].dmin );
10771189
10781190 for (int j = 0 ; j < QK_K /16 ; ++ j ) {
1079- y [i ].scales [j ] = Ls [j ] | (Lm [j ] << 4 );
1080- }
1081-
1082- if (requantize ) {
1083- for (int j = 0 ; j < QK_K /16 ; ++ j ) {
1084- const float d = dm * (y [i ].scales [j ] & 0xF );
1085- if (!d ) continue ;
1086- const float m = mm * (y [i ].scales [j ] >> 4 );
1087- for (int ii = 0 ; ii < 16 ; ++ ii ) {
1088- int l = nearest_int ((x [16 * j + ii ] + m )/d );
1089- l = MAX (0 , MIN (3 , l ));
1090- L [16 * j + ii ] = l ;
1091- }
1191+ float d = dm * Ls [j ];
1192+ float m = mm * Lm [j ];
1193+ float id = d ? 1 /d : 0.f ;
1194+ for (int l = 0 ; l < QK_K /16 ; ++ l ) {
1195+ int q = nearest_int ((x [16 * j + l ] + m )* id );
1196+ q = MAX (0 , MIN (3 , q ));
1197+ L [16 * j + l ] = q ;
10921198 }
10931199 }
1200+ for (int j = 0 ; j < QK_K /16 ; ++ j ) {
1201+ y [i ].scales [j ] = Ls [j ] | (Lm [j ] << 4 );
1202+ }
10941203
10951204 for (int j = 0 ; j < QK_K ; j += 128 ) {
10961205 for (int l = 0 ; l < 32 ; ++ l ) {
@@ -1979,8 +2088,12 @@ static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * G
19792088 const int64_t nb = n_per_row /QK4_0 ;
19802089 for (int ib = 0 ; ib < nb ; ++ ib ) {
19812090 const float * xb = x + QK4_0 * ib ;
1982- const float * qw = quant_weights + QK4_0 * ib ;
1983- for (int j = 0 ; j < QK4_0 ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
2091+ if (quant_weights ) {
2092+ const float * qw = quant_weights + QK4_0 * ib ;
2093+ for (int j = 0 ; j < QK4_0 ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
2094+ } else {
2095+ for (int j = 0 ; j < QK4_0 ; ++ j ) weight [j ] = xb [j ]* xb [j ];
2096+ }
19842097 float d = make_qx_quants (QK4_0 , 8 , xb , L , 1 , weight );
19852098 y [ib ].d = GGML_FP32_TO_FP16 (d );
19862099 for (int j = 0 ; j < 16 ; ++ j ) {
@@ -4877,6 +4990,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
48774990 }
48784991 d = sumqx /sumq2 ;
48794992 float best = d * sumqx ;
4993+ float best_sumqx = sumqx , best_sumq2 = sumq2 ;
48804994 for (int itry = - ntry ; itry <= ntry ; ++ itry ) {
48814995 id = (itry + values [0 ])/max ;
48824996 sumqx = sumq2 = 0 ;
@@ -4890,8 +5004,68 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
48905004 }
48915005 if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
48925006 d = sumqx /sumq2 ; best = d * sumqx ;
5007+ best_sumqx = sumqx ; best_sumq2 = sumq2 ;
5008+ for (int j = 0 ; j < block_size ; ++ j ) {
5009+ float al = id * xb [j ];
5010+ Lb [j ] = best_index_iq4nl (values , al );
5011+ }
5012+ }
5013+ id = (itry + values [15 ])/max ;
5014+ sumqx = sumq2 = 0 ;
5015+ for (int j = 0 ; j < block_size ; ++ j ) {
5016+ float al = id * xb [j ];
5017+ int l = best_index_iq4nl (values , al );
5018+ float q = values [l ];
5019+ float w = weight [j ];
5020+ sumqx += w * q * xb [j ];
5021+ sumq2 += w * q * q ;
5022+ }
5023+ if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
5024+ d = sumqx /sumq2 ; best = d * sumqx ;
5025+ best_sumqx = sumqx ; best_sumq2 = sumq2 ;
5026+ for (int j = 0 ; j < block_size ; ++ j ) {
5027+ float al = id * xb [j ];
5028+ Lb [j ] = best_index_iq4nl (values , al );
5029+ }
5030+ }
5031+ }
5032+ sumqx = best_sumqx ; sumq2 = best_sumq2 ;
5033+ best_sumqx = sumqx ; best_sumq2 = sumq2 ;
5034+ for (int iter = 0 ; iter < 32 * block_size ; ++ iter ) {
5035+ float min_step = INFINITY ;
5036+ int best_j = -1 ; int dir = 0 ;
5037+ for (int j = 0 ; j < block_size ; ++ j ) {
5038+ float w = weight [j ];
5039+ float g = d * w * (xb [j ] - d * values [Lb [j ]]);
5040+ if (g > 0 && Lb [j ] < 15 ) {
5041+ float step = (values [Lb [j ]+ 1 ] - values [Lb [j ]])/g ;
5042+ if (step < min_step ) {
5043+ min_step = step ; best_j = j ; dir = 1 ;
5044+ }
5045+ }
5046+ else if (g < 0 && Lb [j ] > 0 ) {
5047+ float step = (values [Lb [j ]- 1 ] - values [Lb [j ]])/g ;
5048+ if (step < min_step ) {
5049+ min_step = step ; best_j = j ; dir = -1 ;
5050+ }
5051+ }
5052+ }
5053+ if (best_j < 0 ) break ;
5054+
5055+ float new_sumqx = sumqx , new_sumq2 = sumq2 ;
5056+ float w = weight [best_j ];
5057+ new_sumqx += w * xb [best_j ]* (values [Lb [best_j ]+ dir ] - values [Lb [best_j ]]);
5058+ new_sumq2 += w * (values [Lb [best_j ]+ dir ]* values [Lb [best_j ]+ dir ] - values [Lb [best_j ]]* values [Lb [best_j ]]);
5059+ if (new_sumq2 > 0 && new_sumqx * new_sumqx > best * new_sumq2 ) {
5060+ sumqx = new_sumqx ; sumq2 = new_sumq2 ;
5061+ d = sumqx /sumq2 ; best = d * sumqx ;
5062+ Lb [best_j ] += dir ;
5063+ }
5064+ else {
5065+ break ;
48935066 }
48945067 }
5068+
48955069 scales [ib ] = d ;
48965070 float abs_d = fabsf (d );
48975071 if (abs_d > amax_scale ) {
0 commit comments