diff --git a/.gitignore b/.gitignore index 417e591db6d..5ee2f4499cc 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,7 @@ a.out.* AGENTS.local.md .pi/SYSTEM.md + + +/models +/model_zoo \ No newline at end of file diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3357a0d9985..f6470b95574 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -429,7 +429,9 @@ extern "C" { GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) GGML_TYPE_Q1_0 = 41, - GGML_TYPE_COUNT = 42, + GGML_TYPE_STQ1_0 = 42, + + GGML_TYPE_COUNT = 43, }; // precision diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f05683b44cd..f3fda4088d6 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -277,6 +277,14 @@ typedef struct { } block_tq2_0; static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding"); +// 1.3125 bpw +typedef struct { + uint8_t qs[QK_K/8]; // 4-bit code per group of 4 + uint8_t sign[QK_K/32]; // 1-bit table select per group of 4 + ggml_half d; // scale +} block_stq1_0; +static_assert(sizeof(block_stq1_0) == sizeof(ggml_half) + QK_K / 8 + QK_K / 32, "wrong stq1_0 block size/padding"); + // // Super-block quantization structures // @@ -496,6 +504,79 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_ #if defined(GGML_COMMON_IMPL) +// STQ1_0 codebook: index = (sign << 4) | slot -> packed 4-lane ternary pattern. +// +// Each STQ1_0 group has 4 lanes with exactly one 0 and three non-zero (+1/-1) of +// equal magnitude — 4 zero-positions × 2^3 signs = 32 patterns. They are split +// into 4-bit slot + 1-bit sign. +// +// Per-lane encoding (2 bits): -1 -> 0b00, 0 -> 0b01, +1 -> 0b10. A pattern +// byte is qpack = lane3<<6 | lane2<<4 | lane1<<2 | lane0. +// +// Slot (low nibble): pattern with the first non-zero lane fixed to +1. 
+// slot = (zero_pos << 2) | tail_bits +// zero_pos -- index of the zero lane (0..3) +// tail_bits.0 -- sign of 2nd non-zero lane (0 => +1, 1 => -1) +// tail_bits.1 -- sign of 3rd non-zero lane (0 => +1, 1 => -1) +// Sign bit (bit 4): global flip of every non-zero lane (0 unchanged). +// +// Worked example, slot 0 (zero_pos=0, tail=0b00), sign=0: +// lanes = (0, +1, +1, +1) -> qpack = 10_10_10_01 = 0xA9 +// Sign=1 of the same slot flips all non-zero lanes: +// lanes = (0, -1, -1, -1) -> qpack = 00_00_00_01 = 0x01 +// +// The sign=1 half is precomputed so decode is a single load: +// qpack = stq1_0_codebook[(sign << 4) | slot] +GGML_TABLE_BEGIN(uint8_t, stq1_0_codebook, 32) + // sign = 0 (first non-zero lane is +1) + 0xA9, 0x89, 0x29, 0x09, 0xA6, 0x86, 0x26, 0x06, + 0x9A, 0x92, 0x1A, 0x12, 0x6A, 0x62, 0x4A, 0x42, + // sign = 1 (every non-zero lane negated) + 0x01, 0x21, 0x81, 0xA1, 0x04, 0x24, 0x84, 0xA4, + 0x10, 0x18, 0x90, 0x98, 0x40, 0x48, 0x60, 0x68, +GGML_TABLE_END() + +// Reverse maps for the encoder: qpack byte -> (slot, sign). Derived from +// stq1_0_codebook above; entries for non-codebook qpack values are 0xFF (slot) +// / 0 (sign), which the encoder asserts against. 
+GGML_TABLE_BEGIN(uint8_t, stq1_0_qpack_to_slot, 256) + 0xFF, 0x00, 0xFF, 0xFF, 0x04, 0xFF, 0x07, 0xFF, 0xFF, 0x03, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x08, 0xFF, 0x0B, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x09, 0xFF, 0x0A, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0x01, 0xFF, 0xFF, 0x05, 0xFF, 0x06, 0xFF, 0xFF, 0x02, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x0C, 0xFF, 0x0F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x0D, 0xFF, 0x0E, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x0E, 0xFF, 0x0D, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x0F, 0xFF, 0x0C, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0x02, 0xFF, 0xFF, 0x06, 0xFF, 0x05, 0xFF, 0xFF, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x0A, 0xFF, 0x09, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x0B, 0xFF, 0x08, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0x03, 0xFF, 0xFF, 0x07, 0xFF, 0x04, 0xFF, 0xFF, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint8_t, stq1_0_qpack_to_sign, 256) + 0x00, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +GGML_TABLE_END() + GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8) 1, 2, 4, 8, 16, 32, 64, 128 GGML_TABLE_END() diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index b0391a67c88..dddcc712871 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -19,6 +19,7 @@ #define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K #define ggml_vec_dot_q3_K_q8_K_generic ggml_vec_dot_q3_K_q8_K #define 
ggml_vec_dot_q4_K_q8_K_generic ggml_vec_dot_q4_K_q8_K @@ -83,6 +84,7 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 @@ -116,6 +118,7 @@ #define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -159,6 +162,7 @@ #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 @@ -203,6 +207,8 @@ #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 +#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -246,6 +252,7 @@ #define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_stq1_0_q8_K_generic 
ggml_vec_dot_stq1_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
 #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
 #define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
@@ -296,6 +303,7 @@
 #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
+#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K
 #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
 #define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
 #define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c
index fe621332970..9316ef054de 100644
--- a/ggml/src/ggml-cpu/arch/arm/quants.c
+++ b/ggml/src/ggml-cpu/arch/arm/quants.c
@@ -1608,6 +1608,153 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif
 }
 
+// Repack Q8_K quants from contiguous to planar layout, in place, one 64-byte
+// chunk at a time: element 4k+m of a chunk belongs to lane m, and the four
+// lanes become four consecutive 16-byte runs. After the repack, plain
+// vld1q_s8 loads return what vld4q_s8 de-interleaving would have returned on
+// the original data. Only y[i].qs is rewritten.
+//
+//   vy -- pointer to nb block_q8_K blocks, modified in place
+//   nb -- number of blocks
+//
+// Without NEON this is a no-op: ggml_vec_dot_stq1_0_q8_K then forwards to the
+// generic kernel, which indexes y[i].qs in contiguous order, so the layout
+// must be left untouched (and the intrinsics below would not compile).
+void stq1_0_repack_q8_K_inplace(void * vy, int nb) {
+#if defined(__ARM_NEON)
+    block_q8_K * y = (block_q8_K *) vy;
+    for (int i = 0; i < nb; ++i) {
+        for (int j = 0; j < QK_K; j += 64) {
+            const int8x16x4_t v = vld4q_s8(y[i].qs + j);
+            vst1q_s8(y[i].qs + j + 0, v.val[0]);
+            vst1q_s8(y[i].qs + j + 16, v.val[1]);
+            vst1q_s8(y[i].qs + j + 32, v.val[2]);
+            vst1q_s8(y[i].qs + j + 48, v.val[3]);
+        }
+    }
+#else
+    UNUSED(vy);
+    UNUSED(nb);
+#endif
+}
+
+// Dot product of one STQ1_0 row with one Q8_K row; writes the scalar to *s.
+// n is the element count and is processed in whole QK_K blocks (n / QK_K);
+// bs, bx, by are unused and nrc must be 1.
+//
+// NOTE: the NEON path loads y[i].qs with sequential vld1q_s8 and pairs each
+// 16-byte run with one codebook lane, i.e. it expects vy to have been passed
+// through stq1_0_repack_q8_K_inplace() first. The non-NEON fallback expects
+// the original contiguous layout.
+// NOTE(review): vzip1q_u8 / vzip2q_u8 / vqtbl1q_u8 / vqtbl2q_u8 are AArch64
+// intrinsics - confirm that 32-bit ARM NEON builds provide them.
+void ggml_vec_dot_stq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_stq1_0 * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+    float sumf = 0.0f;
+
+    const uint8x16_t m3 = vdupq_n_u8(3);
+    const uint8x16_t mask_0f = vdupq_n_u8(0x0F);
+    const uint8_t (*sign_lut_16)[8] = (const uint8_t (*)[8]) table_b2b_0;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+    // dotprod path: single 32-byte tbl lookup + native vdotq.
+    const uint8x16x2_t codebook2 = { { vld1q_u8(stq1_0_codebook), vld1q_u8(stq1_0_codebook + 16) } };
+    #define STQ1_0_DOT(acc, sx, sy) vdotq_s32((acc), (sx), (sy))
+    #define STQ1_0_LOOKUP(idx) vqtbl2q_u8(codebook2, (idx))
+#else
+    // ARMv8.0 NEON without dotprod: emulate vdotq_s32 (vmull_s8 + vpaddlq_s16)
+    // and split the 32-byte codebook lookup into two vqtbl1q_u8 calls. vqtbl1q_u8
+    // returns 0 for out-of-range indices, so OR-ing the low and high halves
+    // (the high half is indexed by idx-16, which underflows to >=240 for idx<16)
+    // reproduces vqtbl2q_u8 byte-for-byte.
+    const uint8x16_t cb_lo = vld1q_u8(stq1_0_codebook);
+    const uint8x16_t cb_hi = vld1q_u8(stq1_0_codebook + 16);
+    const uint8x16_t v16 = vdupq_n_u8(16);
+    #define STQ1_0_DOT(acc, sx, sy) ggml_vdotq_s32((acc), (sx), (sy))
+    #define STQ1_0_LOOKUP(idx) vorrq_u8(vqtbl1q_u8(cb_lo, (idx)), \
+                                        vqtbl1q_u8(cb_hi, vsubq_u8((idx), v16)))
+#endif
+
+    // Each half processes 16 bytes of x.qs (32 codes), 4 bytes of x.sign,
+    // and 128 bytes of y.qs (4 lanes × 16 bytes × 2 wide-blocks).
+    // sumi0..sumi3, mask_0f, sign_lut_16, m3, STQ1_0_LOOKUP, STQ1_0_DOT are
+    // captured from the enclosing scope.
+#define STQ1_0_DOT_HALF(QS_PTR, SIGN_PTR, YP_PTR) do { \
+    const uint8x16_t packed = vld1q_u8(QS_PTR); \
+    const uint8x16_t lo = vandq_u8(packed, mask_0f); \
+    const uint8x16_t hi = vshrq_n_u8(packed, 4); \
+    const uint8x16_t idx0 = vzip1q_u8(lo, hi); \
+    const uint8x16_t idx1 = vzip2q_u8(lo, hi); \
+    const uint8_t * sp = (SIGN_PTR); \
+    const uint8x16_t s0 = vcombine_u8(vld1_u8(sign_lut_16[sp[0]]), \
+                                      vld1_u8(sign_lut_16[sp[1]])); \
+    const uint8x16_t s1 = vcombine_u8(vld1_u8(sign_lut_16[sp[2]]), \
+                                      vld1_u8(sign_lut_16[sp[3]])); \
+    const uint8x16_t sel_0 = STQ1_0_LOOKUP(vorrq_u8(idx0, s0)); \
+    const uint8x16_t sel_1 = STQ1_0_LOOKUP(vorrq_u8(idx1, s1)); \
+    const int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(sel_0, m3)); \
+    const int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(sel_0, 2), m3)); \
+    const int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(sel_0, 4), m3)); \
+    const int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(sel_0, 6)); \
+    const int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(sel_1, m3)); \
+    const int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(sel_1, 2), m3)); \
+    const int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(sel_1, 4), m3)); \
+    const int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(sel_1, 6)); \
+    const int8_t * yp = (YP_PTR); \
+    sumi0 = STQ1_0_DOT(sumi0, sqx0, vld1q_s8(yp + 0)); \
+    sumi1 = STQ1_0_DOT(sumi1, sqx1, vld1q_s8(yp + 16)); \
+    sumi2 = STQ1_0_DOT(sumi2, sqx2, vld1q_s8(yp + 32)); \
+    sumi3 = STQ1_0_DOT(sumi3, sqx3, vld1q_s8(yp + 48)); \
+    sumi0 = STQ1_0_DOT(sumi0, sqx4, vld1q_s8(yp + 64)); \
+    sumi1 = STQ1_0_DOT(sumi1, sqx5, vld1q_s8(yp + 80)); \
+    sumi2 = STQ1_0_DOT(sumi2, sqx6, vld1q_s8(yp + 96)); \
+    sumi3 = STQ1_0_DOT(sumi3, sqx7, vld1q_s8(yp + 112)); \
+} while (0)
+
+    for (int i = 0; i < nb; ++i) {
+        // 4 accumulators
+        int32x4_t sumi0 = vdupq_n_s32(0);
+        int32x4_t sumi1 = vdupq_n_s32(0);
+        int32x4_t sumi2 = vdupq_n_s32(0);
+        int32x4_t sumi3 = vdupq_n_s32(0);
+
+        STQ1_0_DOT_HALF(x[i].qs, x[i].sign, y[i].qs);
+        STQ1_0_DOT_HALF(x[i].qs + 16, x[i].sign + 4, y[i].qs + 128);
+
+        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        sumi0 = vaddq_s32(vaddq_s32(sumi0, sumi1), vaddq_s32(sumi2, sumi3));
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+    }
+
+#undef STQ1_0_DOT
+#undef STQ1_0_LOOKUP
+#undef STQ1_0_DOT_HALF
+
+    *s = sumf;
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_stq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 8b7acafdaa8..00e1bf990f4 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -396,6 +396,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
     },
+    [GGML_TYPE_STQ1_0] = {
+        .from_float = quantize_row_stq1_0,
+        .vec_dot = ggml_vec_dot_stq1_0_q8_K,
+        .vec_dot_type = GGML_TYPE_Q8_K,
+        .nrows = 1,
+    },
     [GGML_TYPE_I32] = {
         .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
     },
@@ -1337,6 +1343,17 @@ UseGgmlGemm1:;
                     from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
                                (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
                                (ne10_block_end - ne10_block_start) * bs);
+
+#if defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
+                    // STQ1_0 NEON vec_dot wants Q8_K activations in planar
+                    // layout (vld1q reads); deinterleave once here so the
+                    // per-row dot loop amortizes it across all M weight rows.
+ if (src0->type == GGML_TYPE_STQ1_0) { + stq1_0_repack_q8_K_inplace( + (void *)(wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0), + (int)(ne10_block_end - ne10_block_start)); + } +#endif } } } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6bc8dc150ce..8572a7a1758 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -679,6 +679,7 @@ void ggml_compute_forward_add( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_STQ1_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -1130,6 +1131,7 @@ void ggml_compute_forward_add1( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_STQ1_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -1260,6 +1262,7 @@ void ggml_compute_forward_acc( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_STQ1_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -4395,6 +4398,7 @@ void ggml_compute_forward_out_prod( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_STQ1_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -4672,6 +4676,7 @@ void ggml_compute_forward_set( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_STQ1_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -4896,6 +4901,7 @@ void ggml_compute_forward_get_rows( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_STQ1_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -5622,6 +5628,7 @@ void ggml_compute_forward_clamp( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_STQ1_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index e5f9a4083f9..672e5cfd2bd 100644 --- 
a/ggml/src/ggml-cpu/quants.c
+++ b/ggml/src/ggml-cpu/quants.c
@@ -112,6 +112,12 @@ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy,
     quantize_row_tq2_0_ref(x, y, k);
 }
 
+void quantize_row_stq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { // CPU-side STQ1_0 quantizer: forwards to the reference implementation
+    assert(k % QK_K == 0); // only whole QK_K blocks are supported
+    block_stq1_0 * GGML_RESTRICT y = vy;
+    quantize_row_stq1_0_ref(x, y, k);
+}
+
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
@@ -511,6 +517,38 @@ void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     *s = sumf;
 }
 
+void ggml_vec_dot_stq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { // scalar STQ1_0 x Q8_K dot product; the result is stored in *s
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_stq1_0 * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K; // whole blocks only; any remainder of n is not processed
+    float sumf = 0.0f;
+    for (int i = 0; i < nb; ++i) {
+        int32_t sumi = 0; // integer accumulator for one block, scaled once below
+
+        for (int g = 0; g < QK_K/4; ++g) {
+            const uint8_t code = (x[i].qs[g/2] >> (4 * (g & 1))) & 0x0F; // two 4-bit slots per byte, even group in the low nibble
+            const uint8_t sign = (x[i].sign[g/8] >> (g % 8)) & 0x01; // one table-select bit per group, LSB first
+            const uint8_t qpack = stq1_0_codebook[((uint32_t) sign << 4) | code]; // four 2-bit lane codes packed in one byte
+
+            for (int p = 0; p < 4; ++p) {
+                const int q = (qpack >> (2*p)) & 0x3; // lane code: 0, 1 or 2
+                sumi += (q - 1) * y[i].qs[g*4 + p]; // decodes to -1 / 0 / +1; y is read in contiguous element order
+            }
+        }
+
+        sumf += (float) sumi * (GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d);
+    }
+
+    *s = sumf;
+}
+
 void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h
index d4bc87a1c05..370abe6885f 100644
--- a/ggml/src/ggml-cpu/quants.h
+++ b/ggml/src/ggml-cpu/quants.h
@@ -32,6 +32,7 @@ void
quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_stq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -55,6 +56,12 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_stq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +#if defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) +// STQ1_0 ARM NEON path uses planar Q8_K activation; the in-place repack +// is only declared on ARM and called from ggml_compute_forward_mul_mat. 
+void stq1_0_repack_q8_K_inplace(void * vy, int nb); +#endif void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -82,6 +89,7 @@ void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_stq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 10817505d9f..6e879868614 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1086,6 +1086,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; // 256 + static constexpr int qi = 32; // 32 calls per block, 8 elements each + static constexpr int vdr = 1; +}; + ////////////////////// struct ggml_cuda_device_info { diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 61630a35a29..7ce94c71163 
100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -433,6 +433,44 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
     }
 }
 
+template // NOTE(review): the template parameter list looks stripped from this patch text (dst_t is used below) - confirm against the original
+static __global__ void dequantize_block_stq1_0(const void * __restrict__ vx, dst_t * __restrict__ yy) { // each CUDA block decodes one STQ1_0 block; each thread decodes one qs byte (2 groups, 8 outputs)
+    const int64_t i = blockIdx.x;
+    const block_stq1_0 * x = (const block_stq1_0 *) vx + i;
+
+    const int64_t tid = threadIdx.x; // 0..31
+    dst_t * y = yy + i * QK_K + tid * 8; // this thread's 8 consecutive outputs
+
+    const float d = __half2float(x->d);
+
+    const uint8_t qs_val = x->qs[tid];
+    const int sign_byte_idx = tid >> 2; // 8 sign bits per byte = 4 threads per sign byte
+    const int sign_bit_shift1 = (tid & 3) << 1; // bit of the even group: 0, 2, 4, 6
+    const int sign_bit_shift2 = sign_bit_shift1 | 1; // bit of the odd group: 1, 3, 5, 7
+
+    const uint8_t sign_byte = x->sign[sign_byte_idx];
+
+    // For group 1: g = 2*tid
+    const uint8_t code1 = qs_val & 0x0F;
+    const uint8_t sign1 = (sign_byte >> sign_bit_shift1) & 0x01;
+    const uint8_t qpack1 = stq1_0_codebook[((uint32_t) sign1 << 4) | code1];
+
+    for (int p = 0; p < 4; ++p) {
+        const int q = (qpack1 >> (2 * p)) & 0x3; // 2-bit lane code: 0, 1 or 2
+        y[p] = ggml_cuda_cast((float) (q - 1) * d); // -d, 0 or +d
+    }
+
+    // For group 2: g = 2*tid + 1
+    const uint8_t code2 = qs_val >> 4;
+    const uint8_t sign2 = (sign_byte >> sign_bit_shift2) & 0x01;
+    const uint8_t qpack2 = stq1_0_codebook[((uint32_t) sign2 << 4) | code2];
+
+    for (int p = 0; p < 4; ++p) {
+        const int q = (qpack2 >> (2 * p)) & 0x3;
+        y[p + 4] = ggml_cuda_cast((float) (q - 1) * d);
+    }
+}
+
 template
 static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
@@ -593,6 +631,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_iq1_s<<>>(vx, y);
 }
 
+template
+static void dequantize_row_stq1_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K; // number of QK_K blocks, rounded up
+    dequantize_block_stq1_0<<>>(vx, y); // NOTE(review): launch configuration appears stripped from this patch text; the kernel indexes x->qs[threadIdx.x], so it presumably needs QK_K/8 = 32 threads per block - confirm
+}
+
 template
 static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int
nb = (k + QK_K - 1) / QK_K; @@ -748,6 +792,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq1_s_cuda; case GGML_TYPE_IQ1_M: return dequantize_row_iq1_m_cuda; + case GGML_TYPE_STQ1_0: + return dequantize_row_stq1_0_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: @@ -803,6 +849,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq1_s_cuda; case GGML_TYPE_IQ1_M: return dequantize_row_iq1_m_cuda; + case GGML_TYPE_STQ1_0: + return dequantize_row_stq1_0_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 4df1b930882..e992f6c7183 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2354,7 +2354,7 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) { ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; // fusion is not universally faster on Pascal @@ -2396,12 +2396,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; bool use_mul_mat_f = !ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear + bool 
use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
     bool any_gpus_with_slow_fp16 = false;
 
     if (split) {
@@ -4993,6 +4992,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
+                    case GGML_TYPE_STQ1_0:
                     case GGML_TYPE_BF16:
                         return true;
                     default:
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index da48f313a38..02674170024 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -30,6 +30,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
         case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1;
         case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1;
         case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1;
+        case GGML_TYPE_STQ1_0: return vec_dot_stq1_0_q8_1;
         case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1;
         default: return nullptr;
     }
@@ -57,6 +58,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
         case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ;
         case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ;
         case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ;
+        case GGML_TYPE_STQ1_0: return VDR_STQ1_0_Q8_1_MMVQ; // added
         default: return 1;
     }
 }
@@ -1023,6 +1025,12 @@ static void mul_mat_vec_q_switch_type(
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
+        case GGML_TYPE_STQ1_0:
+            mul_mat_vec_q_switch_ncols_dst
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+            break;
         default:
             GGML_ABORT("fatal error");
             break;
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index d1741cc8d7b..495ee87805a 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -1216,10 +1216,57 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     const float2 ds = __half22float2(bq8_1[iqs].ds);
     return d1q * (ds.x*sumi + ds.y*delta);
 }
+#define VDR_STQ1_0_Q8_1_MMVQ 1
+#define VDR_STQ1_0_Q8_1_MMQ 1
+
+static __device__ __forceinline__ float vec_dot_stq1_0_q8_1( // partial dot product: the 8 STQ1_0 weights of one qs byte against 8 Q8_1 quants
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
+    const int & kbx, const int & iqs) {
+
+    const block_stq1_0 * bq = (const block_stq1_0 *) vbq + kbx;
+
+    // iqs: 0..31 (qi=32, each call handles 8 elements = 1 qs byte = 2 groups)
+    const int qb_idx = iqs >> 2; // Q8_1 sub-block index: 0..7
+    const int q8_start = (iqs & 3) << 3; // offset within Q8_1.qs: 0,8,16,24
+
+    const float d = __half2float(bq->d);
+    const float d8 = __low2float(bq8_1[qb_idx].ds);
+
+    const uint8_t qs_byte = bq->qs[iqs]; // 1 byte = 2 groups (4 elements per group)
+
+    // Sign bits - mirrors the dequantize kernel exactly
+    const int sign_byte_idx = iqs >> 2;
+    const int shift1 = (iqs & 3) << 1; // 0, 2, 4, 6
+    const int shift2 = shift1 | 1; // 1, 3, 5, 7
+    const uint8_t sb = bq->sign[sign_byte_idx];
+    const uint8_t sign1 = (sb >> shift1) & 0x01;
+    const uint8_t sign2 = (sb >> shift2) & 0x01;
+
+    const uint8_t code1 = qs_byte & 0x0F;
+    const uint8_t code2 = qs_byte >> 4;
+
+    const uint8_t qpk1 = stq1_0_codebook[((uint32_t)sign1 << 4) | code1];
+    const uint8_t qpk2 = stq1_0_codebook[((uint32_t)sign2 << 4) | code2];
+
+    const int8_t * q8 = bq8_1[qb_idx].qs + q8_start;
+
+    int sumi = 0;
+#pragma unroll
+    for (int p = 0; p < 4; ++p) {
+        const int stq = (int)((qpk1 >> (2*p)) & 0x3) - 1; // lane code 0/1/2 -> -1/0/+1
+        sumi += stq * (int)q8[p];
+    }
+#pragma unroll
+    for (int p = 0; p < 4; ++p) {
+        const int stq = (int)((qpk2 >> (2*p)) & 0x3) - 1;
+        sumi += stq * (int)q8[4 + p];
+    }
+
+    return d * d8 * (float)sumi;
+}
 
 #define VDR_IQ1_M_Q8_1_MMVQ 1
#define VDR_IQ1_M_Q8_1_MMQ 1 - static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 15443aa554a..fdeeae9a77d 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -2335,6 +2335,57 @@ void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RE } } +void quantize_row_stq1_0_ref(const float * GGML_RESTRICT x, block_stq1_0 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; ++i) { + memset(y[i].qs, 0, sizeof(y[i].qs)); + memset(y[i].sign, 0, sizeof(y[i].sign)); + + float amax = 0.0f; + for (int j = 0; j < QK_K; ++j) { + const float a = fabsf(x[j]); + if (a > amax) amax = a; + } + y[i].d = GGML_FP32_TO_FP16(amax); + + // STQ1_0 forces exactly one zero per group of 4. Pick the smallest-|x| + // lane as that zero (ties keep the lowest index); project the other 3 onto {-d, +d} via sign, where d is the block's max |x| stored above. + for (int g = 0; g < QK_K/4; ++g) { + const float * xv = x + g*4; + + int zero_pos = 0; + float min_abs = fabsf(xv[0]); + for (int p = 1; p < 4; ++p) { + const float a = fabsf(xv[p]); + if (a < min_abs) { min_abs = a; zero_pos = p; } + } + + // Per-lane bits: -1 -> 0b00, 0 -> 0b01, +1 -> 0b10 + uint8_t qpack = 0; + for (int p = 0; p < 4; ++p) { + uint8_t lane; + if (p == zero_pos) { + lane = 0x1; + } else { + lane = (xv[p] < 0.0f) ? 
0x0 : 0x2; + } + qpack |= (uint8_t)(lane << (2*p)); + } + + const uint8_t code = stq1_0_qpack_to_slot[qpack]; + const uint8_t sign = stq1_0_qpack_to_sign[qpack]; + assert(code != 0xFF); + + y[i].qs [g/2] |= (uint8_t)((code & 0x0F) << (4 * (g & 1))); + y[i].sign[g/8] |= (uint8_t)(sign << (g % 8)); + } + + x += QK_K; + } +} + size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row); @@ -2349,6 +2400,13 @@ size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * row_size; } +size_t quantize_stq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_STQ1_0, n_per_row); + quantize_row_stq1_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -2407,6 +2465,26 @@ void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_REST } } +void dequantize_row_stq1_0(const block_stq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int g = 0; g < QK_K/4; ++g) { + const uint8_t code = (x[i].qs[g/2] >> (4 * (g & 1))) & 0x0F; // 4-bit slot: even g in the low nibble, odd g in the high nibble + const uint8_t sign = (x[i].sign[g/8] >> (g % 8)) & 0x01; // one sign bit per group, 8 groups per sign byte + const uint8_t qpack = stq1_0_codebook[((uint32_t) sign << 4) | code]; // packed pattern: four 2-bit lane codes + + for (int p = 0; p < 4; ++p) { + const int q = (qpack >> (2*p)) & 0x3; // lane code for element p of the group + *y++ = (float) (q - 1) * d; // (q - 1) is the ternary weight, scaled by the block scale d + } + } + } +} + // ====================== "True" 2-bit (de)-quantization void dequantize_row_iq2_xxs(const 
block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { @@ -5428,6 +5506,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb); } break; + case GGML_TYPE_STQ1_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_stq1_0, data, nb); + } break; case GGML_TYPE_IQ1_S: { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index d56c86da890..977f81b569e 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -34,6 +34,7 @@ GGML_API void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_API void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_stq1_0_ref(const float * GGML_RESTRICT x, block_stq1_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); @@ -62,6 +63,7 @@ GGML_API void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GG GGML_API void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_stq1_0(const block_stq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -86,6 +88,7 @@ GGML_API size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RE GGML_API size_t 
quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_stq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 191cf2fa106..18265c7c89b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -927,6 +927,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .type_size = 0, .is_quantized = false, }, + [GGML_TYPE_STQ1_0] = { + .type_name = "stq1_0", + .blck_size = QK_K, + .type_size = sizeof(block_stq1_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_stq1_0, + .from_float_ref = (ggml_from_float_t) quantize_row_stq1_0_ref, + }, }; const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) { @@ -7703,6 +7711,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q6_K: result = quantize_q6_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_TQ1_0: result = quantize_tq1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_TQ2_0: result = quantize_tq2_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_STQ1_0: result = quantize_stq1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * 
row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 308ebe1f4a1..801135d45a6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -4121,6 +4121,7 @@ class GGMLQuantizationType(IntEnum): MXFP4 = 39 NVFP4 = 40 Q1_0 = 41 + STQ1_0 = 42 class ExpertGatingFuncType(IntEnum): @@ -4175,6 +4176,7 @@ class LlamaFileType(IntEnum): MOSTLY_MXFP4_MOE = 38 # except 1d tensors MOSTLY_NVFP4 = 39 # except 1d tensors MOSTLY_Q1_0 = 40 # except 1d tensors + MOSTLY_STQ1_0 = 41 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -4295,6 +4297,7 @@ class VisionProjectorType: GGMLQuantizationType.MXFP4: (32, 1 + 16), GGMLQuantizationType.NVFP4: (64, 4 + 32), GGMLQuantizationType.Q1_0: (128, 2 + 16), + GGMLQuantizationType.STQ1_0: (256, 42), } diff --git a/include/llama.h b/include/llama.h index 2ea226726ad..cbdc09c32cc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -155,6 +155,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors LLAMA_FTYPE_MOSTLY_NVFP4 = 39, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q1_0 = 40, // except 1d tensors + LLAMA_FTYPE_MOSTLY_STQ1_0 = 41, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 4e65a45a50d..17c13a6152a 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -56,6 +56,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; + case 
LLAMA_FTYPE_MOSTLY_STQ1_0: return "STQ1_0 - 1.31 bpw ternary"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; @@ -749,6 +750,7 @@ llama_model_loader::llama_model_loader( case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; case GGML_TYPE_TQ1_0: ftype = LLAMA_FTYPE_MOSTLY_TQ1_0; break; case GGML_TYPE_TQ2_0: ftype = LLAMA_FTYPE_MOSTLY_TQ2_0; break; + case GGML_TYPE_STQ1_0: ftype = LLAMA_FTYPE_MOSTLY_STQ1_0; break; case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 43e05c3d56f..fff9c6782fa 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -383,7 +383,8 @@ static ggml_type tensor_type_fallback(quantize_state_impl & qs, const ggml_tenso case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: return_type = GGML_TYPE_Q4_0; break; + case GGML_TYPE_TQ2_0: + case GGML_TYPE_STQ1_0: return_type = GGML_TYPE_Q4_0; break; case GGML_TYPE_Q4_K: return_type = GGML_TYPE_Q5_0; break; case GGML_TYPE_Q5_K: return_type = GGML_TYPE_Q5_1; break; case GGML_TYPE_Q6_K: return_type = GGML_TYPE_Q8_0; break; @@ -480,7 +481,7 @@ static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = GGML_TYPE_IQ3_S; } - else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) { + else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0 || ftype == LLAMA_FTYPE_MOSTLY_STQ1_0) { new_type = GGML_TYPE_Q4_K; } } @@ -817,6 +818,7 @@ ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q6_K: return GGML_TYPE_Q6_K; case LLAMA_FTYPE_MOSTLY_TQ1_0: return GGML_TYPE_TQ1_0; case 
LLAMA_FTYPE_MOSTLY_TQ2_0: return GGML_TYPE_TQ2_0; + case LLAMA_FTYPE_MOSTLY_STQ1_0: return GGML_TYPE_STQ1_0; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return GGML_TYPE_IQ2_XXS; case LLAMA_FTYPE_MOSTLY_IQ2_XS: return GGML_TYPE_IQ2_XS; case LLAMA_FTYPE_MOSTLY_IQ2_S: return GGML_TYPE_IQ2_XS; diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 3d33d47d98b..b54a4773619 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -46,6 +46,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, { "TQ1_0", LLAMA_FTYPE_MOSTLY_TQ1_0, " 1.69 bpw ternarization", }, { "TQ2_0", LLAMA_FTYPE_MOSTLY_TQ2_0, " 2.06 bpw ternarization", }, + { "STQ1_0", LLAMA_FTYPE_MOSTLY_STQ1_0, " 1.31 bpw ternarization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.96G, +3.5199 ppl @ Llama-3-8B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.96G, +3.1836 ppl @ Llama-3-8B", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },