@@ -1768,10 +1768,8 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
17681768    float scale = suml2 ? sumlx/suml2 : 0.0f;
17691769    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
17701770    float best = scale * sumlx;
1771+     float best_sumlx = sumlx, best_suml2 = suml2;
17711772    for (int is = -9; is <= 9; ++is) {
1772-         if (is == 0) {
1773-             continue;
1774-         }
17751773        iscale = -(nmax + 0.1f*is) / max;
17761774        sumlx = suml2 = 0;
17771775        for (int i = 0; i < n; ++i) {
@@ -1787,7 +1785,66 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
17871785                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
17881786            }
17891787            scale = sumlx/suml2; best = scale*sumlx;
1788+             best_sumlx = sumlx; best_suml2 = suml2;
1789+         }
1790+         iscale = (nmax-1 + 0.1f*is) / max;
1791+         sumlx = suml2 = 0;
1792+         for (int i = 0; i < n; ++i) {
1793+             int l = nearest_int(iscale * x[i]);
1794+             l = MAX(-nmax, MIN(nmax-1, l));
1795+             float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
1796+             sumlx += w*x[i]*l;
1797+             suml2 += w*l*l;
17901798        }
1799+         if (suml2 > 0 && sumlx*sumlx > best*suml2) {
1800+             for (int i = 0; i < n; ++i) {
1801+                 int l = nearest_int(iscale * x[i]);
1802+                 L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
1803+             }
1804+             scale = sumlx/suml2; best = scale*sumlx;
1805+             best_sumlx = sumlx; best_suml2 = suml2;
1806+         }
1807+     }
1808+ 
1809+     sumlx = best_sumlx; suml2 = best_suml2;
1810+     for (int iter = 0; iter < n*(2*nmax-1); ++iter) {
1811+         float abs_gmax = 0, gmax = 0;
1812+         int best_j = -1;
1813+         for (int j = 0; j < n; ++j) {
1814+             float w = qw ? qw[j] : rmse_type == 1 ? x[j] * x[j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[j]) : sqrtf(fabsf(x[j]));
1815+             int l = L[j] - nmax;
1816+             float g = scale * w * (x[j] - scale*l);
1817+             if ((g > 0 && l < nmax-1) || (g < 0 && l > -nmax)) {
1818+                 float ag = fabsf(g);
1819+                 if (ag > abs_gmax) {
1820+                     abs_gmax = ag; gmax = g; best_j = j;
1821+                 }
1822+             }
1823+         }
1824+         if (best_j < 0) break;
1825+ 
1826+         float new_sumlx = sumlx, new_suml2 = suml2;
1827+         float w = qw ? qw[best_j] : rmse_type == 1 ? x[best_j] * x[best_j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[best_j]) : sqrtf(fabsf(x[best_j]));
1828+         int l = L[best_j] - nmax;
1829+         if (gmax > 0) {
1830+             new_sumlx += w*x[best_j];
1831+             new_suml2 += w*(2*l + 1);
1832+             l += 1;
1833+         } else {
1834+             new_sumlx -= w*x[best_j];
1835+             new_suml2 -= w*(2*l - 1);
1836+             l -= 1;
1837+         }
1838+         if (new_suml2 > 0 && new_sumlx*new_sumlx > best*new_suml2) {
1839+             sumlx = new_sumlx; suml2 = new_suml2;
1840+             scale = sumlx/suml2; best = scale*sumlx;
1841+             L[best_j] = l + nmax;
1842+             GGML_ASSERT(L[best_j] >= 0 && L[best_j] <= 2*nmax-1);
1843+         }
1844+         else {
1845+             break;
1846+         }
1847+ 
17911848    }
17921849    return scale;
17931850}
@@ -3254,8 +3311,12 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
32543311    const int64_t nb = n_per_row/QK4_0;
32553312    for (int ib = 0; ib < nb; ++ib) {
32563313        const float * xb = x + QK4_0 * ib;
3257-         const float * qw = quant_weights + QK4_0 * ib;
3258-         for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3314+         if (quant_weights) {
3315+             const float * qw = quant_weights + QK4_0 * ib;
3316+             for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3317+         } else {
3318+             for (int j = 0; j < QK4_0; ++j) weight[j] = xb[j]*xb[j];
3319+         }
32593320        float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
32603321        y[ib].d = GGML_FP32_TO_FP16(d);
32613322        for (int j = 0; j < 16; ++j) {
@@ -14581,6 +14642,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
1458114642        }
1458214643        d = sumqx/sumq2;
1458314644        float best = d*sumqx;
14645+         float best_sumqx = sumqx, best_sumq2 = sumq2;
1458414646        for (int itry = -ntry; itry <= ntry; ++itry) {
1458514647            id = (itry + values[0])/max;
1458614648            sumqx = sumq2 = 0;
@@ -14594,8 +14656,67 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
1459414656            }
1459514657            if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
1459614658                d = sumqx/sumq2; best = d * sumqx;
14659+                 best_sumqx = sumqx; best_sumq2 = sumq2;
14660+                 for (int j = 0; j < block_size; ++j) {
14661+                     float al = id*xb[j];
14662+                     Lb[j] = best_index_iq4nl(values, al);
14663+                 }
14664+             }
14665+             id = (itry + values[15])/max;
14666+             sumqx = sumq2 = 0;
14667+             for (int j = 0; j < block_size; ++j) {
14668+                 float al = id*xb[j];
14669+                 int l = best_index_iq4nl(values, al);
14670+                 float q = values[l];
14671+                 float w = weight[j];
14672+                 sumqx += w*q*xb[j];
14673+                 sumq2 += w*q*q;
14674+             }
14675+             if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
14676+                 d = sumqx/sumq2; best = d * sumqx;
14677+                 best_sumqx = sumqx; best_sumq2 = sumq2;
14678+                 for (int j = 0; j < block_size; ++j) {
14679+                     float al = id*xb[j];
14680+                     Lb[j] = best_index_iq4nl(values, al);
14681+                 }
1459714682            }
1459814683        }
14684+         sumqx = best_sumqx; sumq2 = best_sumq2;
14685+         for (int iter = 0; iter < 32*block_size; ++iter) {
14686+             float min_step = INFINITY;
14687+             int best_j = -1; int dir = 0;
14688+             for (int j = 0; j < block_size; ++j) {
14689+                 float w = weight[j];
14690+                 float g = d * w * (xb[j] - d*values[Lb[j]]);
14691+                 if (g > 0 && Lb[j] < 15) {
14692+                     float step = (values[Lb[j]+1] - values[Lb[j]])/g;
14693+                     if (step < min_step) {
14694+                         min_step = step; best_j = j; dir = 1;
14695+                     }
14696+                 }
14697+                 else if (g < 0 && Lb[j] > 0) {
14698+                     float step = (values[Lb[j]-1] - values[Lb[j]])/g;
14699+                     if (step < min_step) {
14700+                         min_step = step; best_j = j; dir = -1;
14701+                     }
14702+                 }
14703+             }
14704+             if (best_j < 0) break;
14705+ 
14706+             float new_sumqx = sumqx, new_sumq2 = sumq2;
14707+             float w = weight[best_j];
14708+             new_sumqx += w*xb[best_j]*(values[Lb[best_j]+dir] - values[Lb[best_j]]);
14709+             new_sumq2 += w*(values[Lb[best_j]+dir]*values[Lb[best_j]+dir] - values[Lb[best_j]]*values[Lb[best_j]]);
14710+             if (new_sumq2 > 0 && new_sumqx*new_sumqx > best*new_sumq2) {
14711+                 sumqx = new_sumqx; sumq2 = new_sumq2;
14712+                 d = sumqx/sumq2; best = d*sumqx;
14713+                 Lb[best_j] += dir;
14714+             }
14715+             else {
14716+                 break;
14717+             }
14718+         }
14719+ 
1459914720        scales[ib] = d;
1460014721        float abs_d = fabsf(d);
1460114722        if (abs_d > amax_scale) {
0 commit comments