@@ -844,8 +844,7 @@ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block s
844844static void quantize_row_q4_0_reference (const float * restrict x , block_q4_0 * restrict y , int k ) {
845845 static const int qk = QK4_0 ;
846846
847- assert (qk / 16 == 0 );
848- assert ( k % qk == 0 );
847+ assert (k % qk == 0 );
849848
850849 const int nb = k / qk ;
851850
@@ -866,20 +865,16 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
866865
867866 y [i ].d = d ;
868867
869- uint64_t qs [QK4_0 / 16 ] = {0 };
870-
871868 for (int l = 0 ; l < qk /2 ; ++ l ) {
872869 const float x0 = x [i * qk + 0 + l ]* id ;
873870 const float x1 = x [i * qk + qk /2 + l ]* id ;
874871
875- const uint64_t xi0 = MIN (15 , (int8_t )(x0 + 8.5f ));
876- const uint64_t xi1 = MIN (15 , (int8_t )(x1 + 8.5f ));
872+ const uint8_t xi0 = MIN (15 , (int8_t )(x0 + 8.5f ));
873+ const uint8_t xi1 = MIN (15 , (int8_t )(x1 + 8.5f ));
877874
878- qs [l / 8 ] | = xi0 << ( 8 * ( l & 7 )) ;
879- qs [l / 8 ] |= xi1 << ( 8 * ( l & 7 ) + 4 ) ;
875+ y [ i ]. qs [l ] = xi0 ;
876+ y [ i ]. qs [l ] |= xi1 << 4 ;
880877 }
881-
882- memcpy (y [i ].qs , qs , qk /2 );
883878 }
884879}
885880
@@ -890,8 +885,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k
890885static void quantize_row_q4_1_reference (const float * restrict x , block_q4_1 * restrict y , int k ) {
891886 const int qk = QK4_1 ;
892887
893- assert (qk / 16 == 0 );
894- assert ( k % qk == 0 );
888+ assert (k % qk == 0 );
895889
896890 const int nb = k / qk ;
897891
@@ -912,20 +906,16 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
912906 y [i ].d = d ;
913907 y [i ].m = min ;
914908
915- uint64_t qs [QK4_1 / 16 ] = {0 };
916-
917909 for (int l = 0 ; l < qk /2 ; ++ l ) {
918910 const float x0 = (x [0 + l ] - min )* id ;
919911 const float x1 = (x [qk /2 + l ] - min )* id ;
920912
921- const uint64_t xi0 = MIN (15 , (int8_t )(x0 + 0.5f ));
922- const uint64_t xi1 = MIN (15 , (int8_t )(x1 + 0.5f ));
913+ const uint8_t xi0 = MIN (15 , (int8_t )(x0 + 0.5f ));
914+ const uint8_t xi1 = MIN (15 , (int8_t )(x1 + 0.5f ));
923915
924- qs [l / 8 ] | = xi0 << ( 8 * ( l & 7 )) ;
925- qs [l / 8 ] |= xi1 << ( 8 * ( l & 7 ) + 4 ) ;
916+ y [ i ]. qs [l ] = xi0 ;
917+ y [ i ]. qs [l ] |= xi1 << 4 ;
926918 }
927-
928- memcpy (y [i ].qs , qs , qk /2 );
929919 }
930920}
931921
@@ -937,8 +927,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k
937927static void quantize_row_q4_2_reference (const float * restrict x , block_q4_2 * restrict y , int k ) {
938928 static const int qk = QK4_2 ;
939929
940- assert (qk / 16 == 0 );
941- assert ( k % qk == 0 );
930+ assert (k % qk == 0 );
942931
943932 const int nb = k / qk ;
944933
@@ -983,8 +972,7 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict y, int k
983972static void quantize_row_q5_0_reference (const float * restrict x , block_q5_0 * restrict y , int k ) {
984973 static const int qk = QK5_0 ;
985974
986- assert (qk / 16 == 0 );
987- assert ( k % qk == 0 );
975+ assert (k % qk == 0 );
988976
989977 const int nb = k / qk ;
990978
@@ -1006,24 +994,21 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
1006994 y [i ].d = d ;
1007995
1008996 uint32_t qh = 0 ;
1009- uint64_t qs [QK5_0 / 16 ] = {0 };
1010997
1011998 for (int l = 0 ; l < qk /2 ; ++ l ) {
1012999 const float x0 = x [i * qk + 0 + l ]* id ;
10131000 const float x1 = x [i * qk + qk /2 + l ]* id ;
10141001
1015- const uint64_t xi0 = MIN (31 , (int8_t )(x0 + 16.5f ));
1016- const uint64_t xi1 = MIN (31 , (int8_t )(x1 + 16.5f ));
1002+ const uint8_t xi0 = MIN (31 , (int8_t )(x0 + 16.5f ));
1003+ const uint8_t xi1 = MIN (31 , (int8_t )(x1 + 16.5f ));
10171004
1018- qs [l /8 ] |= xi0 << (8 * (l & 7 ));
1019- qs [l /8 ] |= xi1 << (8 * (l & 7 ) + 4 );
1005+ y [i ].qs [l ] = (xi0 & 0x0F ) | ((xi1 & 0x0F ) << 4 );
10201006
10211007 // get the 5-th bit and store it in qh at the right position
10221008 qh |= ((xi0 & 0x10 ) >> 4 ) << (l + 0 );
10231009 qh |= ((xi1 & 0x10 ) >> 4 ) << (l + qk /2 );
10241010 }
10251011
1026- memcpy ( y [i ].qs , qs , qk /2 );
10271012 memcpy (& y [i ].qh , & qh , sizeof (qh ));
10281013 }
10291014}
@@ -1033,50 +1018,50 @@ static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k
10331018}
10341019
10351020static void quantize_row_q5_1_reference (const float * restrict x , block_q5_1 * restrict y , int k ) {
1036- assert (k % QK5_1 == 0 );
1037- const int nb = k / QK5_1 ;
1021+ const int qk = QK5_1 ;
1022+
1023+ assert (k % qk == 0 );
1024+
1025+ const int nb = k / qk ;
10381026
10391027 for (int i = 0 ; i < nb ; i ++ ) {
10401028 float min = FLT_MAX ;
10411029 float max = - FLT_MAX ;
10421030
1043- for (int l = 0 ; l < QK5_1 ; l ++ ) {
1044- const float v = x [i * QK5_1 + l ];
1031+ for (int l = 0 ; l < qk ; l ++ ) {
1032+ const float v = x [i * qk + l ];
1033+
10451034 if (v < min ) min = v ;
10461035 if (v > max ) max = v ;
10471036 }
10481037
1049- const float d = (max - min ) / ((1 << 5 ) - 1 );
1038+ const float d = (max - min ) / ((1 << 5 ) - 1 );
10501039 const float id = d ? 1.0f /d : 0.0f ;
10511040
10521041 y [i ].d = GGML_FP32_TO_FP16 (d );
10531042 y [i ].m = GGML_FP32_TO_FP16 (min );
10541043
10551044 uint32_t qh = 0 ;
10561045
1057- for (int l = 0 ; l < QK5_1 ; l += 2 ) {
1058- const float v0 = (x [i * QK5_1 + l + 0 ] - min )* id ;
1059- const float v1 = (x [i * QK5_1 + l + 1 ] - min )* id ;
1046+ for (int l = 0 ; l < qk / 2 ; ++ l ) {
1047+ const float x0 = (x [i * qk + 0 + l ] - min )* id ;
1048+ const float x1 = (x [i * qk + qk / 2 + l ] - min )* id ;
10601049
1061- const uint32_t vi0 = (int ) ( v0 + 0.5f );
1062- const uint32_t vi1 = (int ) ( v1 + 0.5f );
1050+ const uint8_t xi0 = (uint8_t )( x0 + 0.5f );
1051+ const uint8_t xi1 = (uint8_t )( x1 + 0.5f );
10631052
1064- y [i ].qs [l / 2 ] = (vi0 & 0x0F ) | ((vi1 & 0x0F ) << 4 );
1053+ y [i ].qs [l ] = (xi0 & 0x0F ) | ((xi1 & 0x0F ) << 4 );
10651054
10661055 // get the 5-th bit and store it in qh at the right position
1067- qh |= ((vi0 & 0x10 ) >> 4 ) << (l + 0 );
1068- qh |= ((vi1 & 0x10 ) >> 4 ) << (l + 1 );
1056+ qh |= ((xi0 & 0x10 ) >> 4 ) << (l + 0 );
1057+ qh |= ((xi1 & 0x10 ) >> 4 ) << (l + qk / 2 );
10691058 }
10701059
10711060 memcpy (& y [i ].qh , & qh , sizeof (y [i ].qh ));
10721061 }
10731062}
10741063
1075- static void quantize_row_q5_1 (const float * restrict x , void * restrict vy , int k ) {
1076- assert (k % QK5_1 == 0 );
1077-
1078- block_q5_1 * restrict y = vy ;
1079-
1064+ static void quantize_row_q5_1 (const float * restrict x , void * restrict y , int k ) {
10801065 quantize_row_q5_1_reference (x , y , k );
10811066}
10821067
@@ -1316,8 +1301,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
13161301static void dequantize_row_q4_0 (const block_q4_0 * restrict x , float * restrict y , int k ) {
13171302 static const int qk = QK4_0 ;
13181303
1319- assert (qk / 16 == 0 );
1320- assert ( k % qk == 0 );
1304+ assert (k % qk == 0 );
13211305
13221306 const int nb = k / qk ;
13231307
@@ -1337,8 +1321,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
13371321static void dequantize_row_q4_1 (const block_q4_1 * restrict x , float * restrict y , int k ) {
13381322 static const int qk = QK4_1 ;
13391323
1340- assert (qk / 16 == 0 );
1341- assert ( k % qk == 0 );
1324+ assert (k % qk == 0 );
13421325
13431326 const int nb = k / qk ;
13441327
@@ -1360,8 +1343,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
13601343 // BORKEN !!!
13611344 static const int qk = QK4_2 ;
13621345
1363- assert (qk / 16 == 0 );
1364- assert ( k % qk == 0 );
1346+ assert (k % qk == 0 );
13651347
13661348 const int nb = k / qk ;
13671349
@@ -1381,8 +1363,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
13811363static void dequantize_row_q5_0 (const block_q5_0 * restrict x , float * restrict y , int k ) {
13821364 static const int qk = QK4_0 ;
13831365
1384- assert (qk / 16 == 0 );
1385- assert ( k % qk == 0 );
1366+ assert (k % qk == 0 );
13861367
13871368 const int nb = k / qk ;
13881369
@@ -1405,39 +1386,29 @@ static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict
14051386 }
14061387}
14071388
1408- static void dequantize_row_q5_1 (const void * restrict vx , float * restrict y , int k ) {
1409- assert (k % QK5_1 == 0 );
1410- const int nb = k / QK5_1 ;
1389+ static void dequantize_row_q5_1 (const block_q5_1 * restrict x , float * restrict y , int k ) {
1390+ static const int qk = QK5_1 ;
14111391
1412- const block_q5_1 * restrict x = vx ;
1392+ assert (k % qk == 0 );
1393+
1394+ const int nb = k / qk ;
14131395
14141396 for (int i = 0 ; i < nb ; i ++ ) {
14151397 const float d = GGML_FP16_TO_FP32 (x [i ].d );
14161398 const float m = GGML_FP16_TO_FP32 (x [i ].m );
14171399
1418- const uint8_t * restrict pp = x [i ].qs ;
1419-
14201400 uint32_t qh ;
14211401 memcpy (& qh , x [i ].qh , sizeof (qh ));
14221402
1423- for (int l = 0 ; l < QK5_1 ; l += 2 ) {
1424- const uint8_t vi = pp [l /2 ];
1425-
1426- // extract the 5-th bit from qh
1427- const uint8_t vh0 = ((qh & (1u << (l + 0 ))) >> (l + 0 )) << 4 ;
1428- const uint8_t vh1 = ((qh & (1u << (l + 1 ))) >> (l + 1 )) << 4 ;
1429-
1430- const uint8_t vi0 = (vi & 0x0F ) | vh0 ;
1431- const uint8_t vi1 = (vi >> 4 ) | vh1 ;
1432-
1433- const float v0 = vi0 * d + m ;
1434- const float v1 = vi1 * d + m ;
1403+ for (int j = 0 ; j < qk /2 ; ++ j ) {
1404+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4 ;
1405+ const uint8_t xh_1 = ((qh & (1u << (j + 16 ))) >> (j + 12 ));
14351406
1436- y [ i * QK5_1 + l + 0 ] = v0 ;
1437- y [ i * QK5_1 + l + 1 ] = v1 ;
1407+ const int x0 = ( x [ i ]. qs [ j ] & 0xf ) | xh_0 ;
1408+ const int x1 = ( x [ i ]. qs [ j ] >> 4 ) | xh_1 ;
14381409
1439- assert (! isnan ( y [i * QK5_1 + l + 0 ])) ;
1440- assert (! isnan ( y [i * QK5_1 + l + 1 ])) ;
1410+ y [i * qk + j + 0 ] = x0 * d + m ;
1411+ y [i * qk + j + qk / 2 ] = x1 * d + m ;
14411412 }
14421413 }
14431414}
@@ -1500,7 +1471,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
15001471 .vec_dot_type = GGML_TYPE_Q8_0 ,
15011472 },
15021473 [GGML_TYPE_Q5_1 ] = {
1503- .dequantize_row_q = dequantize_row_q5_1 ,
1474+ .dequantize_row_q = ( dequantize_row_q_t ) dequantize_row_q5_1 ,
15041475 .quantize_row_q = quantize_row_q5_1 ,
15051476 .quantize_row_q_reference = (quantize_row_q_t ) quantize_row_q5_1_reference ,
15061477 .quantize_row_q_dot = quantize_row_q8_1 ,
@@ -2748,11 +2719,12 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
27482719}
27492720
27502721static void ggml_vec_dot_q5_1_q8_1 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2751- const int nb = n / QK8_1 ;
2722+ const int qk = QK8_1 ;
2723+ const int nb = n / qk ;
27522724
2753- assert (n % QK8_1 == 0 );
2725+ assert (n % qk == 0 );
27542726 assert (nb % 2 == 0 );
2755- assert (QK8_1 == QK5_1 );
2727+ assert (qk == QK5_1 );
27562728
27572729 const block_q5_1 * restrict x = vx ;
27582730 const block_q8_1 * restrict y = vy ;
@@ -2788,13 +2760,9 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
27882760 const int8x16_t v0l = vreinterpretq_s8_u8 (vandq_u8 (v0 , vdupq_n_u8 (0x0F )));
27892761 const int8x16_t v0h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0 , 4 ));
27902762
2791- // interleave
2792- const int8x16_t v0lz = vzip1q_s8 (v0l , v0h );
2793- const int8x16_t v0hz = vzip2q_s8 (v0l , v0h );
2794-
27952763 // add
2796- const int8x16_t v0lf = vorrq_s8 (v0lz , qhl );
2797- const int8x16_t v0hf = vorrq_s8 (v0hz , qhh );
2764+ const int8x16_t v0lf = vorrq_s8 (v0l , qhl );
2765+ const int8x16_t v0hf = vorrq_s8 (v0h , qhh );
27982766
27992767 // load y
28002768 const int8x16_t v1l = vld1q_s8 (y0 -> qs );
@@ -2917,36 +2885,28 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
29172885
29182886 * s = hsum_float_8 (acc ) + summs ;
29192887#else
2888+ // scalar
29202889 float sumf = 0.0 ;
29212890
29222891 for (int i = 0 ; i < nb ; i ++ ) {
2923- const uint8_t * restrict x0 = x [i ].qs ;
2924- const int8_t * restrict y0 = y [i ].qs ;
2892+ const int8_t * py = y [i ].qs ;
29252893
29262894 uint32_t qh ;
29272895 memcpy (& qh , x [i ].qh , sizeof (qh ));
29282896
2929- const float d = GGML_FP16_TO_FP32 (x [i ].d );
2930- const float m = GGML_FP16_TO_FP32 (x [i ].m );
2931-
2932- int sxy = 0 ;
2933-
2934- for (int j = 0 ; j < QK8_1 /2 ; j ++ ) {
2935- const uint8_t v0 = x0 [j ];
2936-
2937- const int x0_0h = ((qh & (1u << (2 * j + 0 ))) >> (2 * j + 0 )) << 4 ;
2938- const int x1_0h = ((qh & (1u << (2 * j + 1 ))) >> (2 * j + 1 )) << 4 ;
2897+ int sumi = 0 ;
29392898
2940- const int x0_0 = (v0 & 0x0F ) | x0_0h ;
2941- const int x1_0 = (v0 >> 4 ) | x1_0h ;
2899+ for (int j = 0 ; j < qk /2 ; ++ j ) {
2900+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4 ;
2901+ const uint8_t xh_1 = ((qh & (1u << (j + 16 ))) >> (j + 12 ));
29422902
2943- const int y0_0 = y0 [ 2 * j + 0 ] ;
2944- const int y1_0 = y0 [ 2 * j + 1 ] ;
2903+ const int32_t x0 = ( x [ i ]. qs [ j ] & 0xF ) | xh_0 ;
2904+ const int32_t x1 = ( x [ i ]. qs [ j ] >> 4 ) | xh_1 ;
29452905
2946- sxy += x0_0 * y0_0 + x1_0 * y1_0 ;
2906+ sumi += ( x0 * py [ j ]) + ( x1 * py [ j + qk / 2 ]) ;
29472907 }
29482908
2949- sumf += (d * sxy )* y [i ].d + m * (y [i ].s0 + y [i ].s1 );
2909+ sumf += (GGML_FP16_TO_FP32 ( x [ i ]. d )* y [i ].d ) * sumi + GGML_FP16_TO_FP32 ( x [ i ]. m ) * (y [i ].s0 + y [i ].s1 );
29502910 }
29512911
29522912 * s = sumf ;
0 commit comments