Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,7 @@ a.out.*

AGENTS.local.md
.pi/SYSTEM.md


/models
/model_zoo
4 changes: 3 additions & 1 deletion ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,9 @@ extern "C" {
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
GGML_TYPE_Q1_0 = 41,
GGML_TYPE_COUNT = 42,
GGML_TYPE_STQ1_0 = 42,

GGML_TYPE_COUNT = 43,
};

// precision
Expand Down
81 changes: 81 additions & 0 deletions ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,14 @@ typedef struct {
} block_tq2_0;
static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");

// 1.3125 bpw
typedef struct {
uint8_t qs[QK_K/8]; // 4-bit code per group of 4
uint8_t sign[QK_K/32]; // 1-bit table select per group of 4
ggml_half d; // scale
} block_stq1_0;
static_assert(sizeof(block_stq1_0) == sizeof(ggml_half) + QK_K / 8 + QK_K / 32, "wrong stq1_0 block size/padding");

//
// Super-block quantization structures
//
Expand Down Expand Up @@ -496,6 +504,79 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_

#if defined(GGML_COMMON_IMPL)

// STQ1_0 codebook: index = (sign << 4) | slot -> packed 4-lane ternary pattern.
//
// Each STQ1_0 group has 4 lanes with exactly one 0 and three non-zero (+1/-1) of
// equal magnitude — 4 zero-positions × 2^3 signs = 32 patterns. They are split
// into 4-bit slot + 1-bit sign.
//
// Per-lane encoding (2 bits): -1 -> 0b00, 0 -> 0b01, +1 -> 0b10. A pattern
// byte is qpack = lane3<<6 | lane2<<4 | lane1<<2 | lane0.
//
// Slot (low nibble): pattern with the first non-zero lane fixed to +1.
// slot = (zero_pos << 2) | tail_bits
// zero_pos -- index of the zero lane (0..3)
// tail_bits.0 -- sign of 2nd non-zero lane (0 => +1, 1 => -1)
// tail_bits.1 -- sign of 3rd non-zero lane (0 => +1, 1 => -1)
// Sign bit (bit 4): global flip of every non-zero lane (0 unchanged).
//
// Worked example, slot 0 (zero_pos=0, tail=0b00), sign=0:
// lanes = (0, +1, +1, +1) -> qpack = 10_10_10_01 = 0xA9
// Sign=1 of the same slot flips all non-zero lanes:
// lanes = (0, -1, -1, -1) -> qpack = 00_00_00_01 = 0x01
//
// The sign=1 half is precomputed so decode is a single load:
// qpack = stq1_0_codebook[(sign << 4) | slot]
GGML_TABLE_BEGIN(uint8_t, stq1_0_codebook, 32)
// sign = 0 (first non-zero lane is +1)
0xA9, 0x89, 0x29, 0x09, 0xA6, 0x86, 0x26, 0x06,
0x9A, 0x92, 0x1A, 0x12, 0x6A, 0x62, 0x4A, 0x42,
// sign = 1 (every non-zero lane negated)
0x01, 0x21, 0x81, 0xA1, 0x04, 0x24, 0x84, 0xA4,
0x10, 0x18, 0x90, 0x98, 0x40, 0x48, 0x60, 0x68,
GGML_TABLE_END()

// Reverse maps for the encoder: qpack byte -> (slot, sign). Derived from
// stq1_0_codebook above; entries for non-codebook qpack values are 0xFF (slot)
// / 0 (sign), which the encoder asserts against.
GGML_TABLE_BEGIN(uint8_t, stq1_0_qpack_to_slot, 256)
0xFF, 0x00, 0xFF, 0xFF, 0x04, 0xFF, 0x07, 0xFF, 0xFF, 0x03, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0x08, 0xFF, 0x0B, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x09, 0xFF, 0x0A, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0x01, 0xFF, 0xFF, 0x05, 0xFF, 0x06, 0xFF, 0xFF, 0x02, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0x0C, 0xFF, 0x0F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x0D, 0xFF, 0x0E, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0x0E, 0xFF, 0x0D, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x0F, 0xFF, 0x0C, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0x02, 0xFF, 0xFF, 0x06, 0xFF, 0x05, 0xFF, 0xFF, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0x0A, 0xFF, 0x09, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x0B, 0xFF, 0x08, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0x03, 0xFF, 0xFF, 0x07, 0xFF, 0x04, 0xFF, 0xFF, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
GGML_TABLE_END()

GGML_TABLE_BEGIN(uint8_t, stq1_0_qpack_to_sign, 256)
0x00, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
GGML_TABLE_END()

GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8)
1, 2, 4, 8, 16, 32, 64, 128
GGML_TABLE_END()
Expand Down
8 changes: 8 additions & 0 deletions ggml/src/ggml-cpu/arch-fallback.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
#define ggml_vec_dot_q3_K_q8_K_generic ggml_vec_dot_q3_K_q8_K
#define ggml_vec_dot_q4_K_q8_K_generic ggml_vec_dot_q4_K_q8_K
Expand Down Expand Up @@ -83,6 +84,7 @@
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
Expand Down Expand Up @@ -116,6 +118,7 @@
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
Expand Down Expand Up @@ -159,6 +162,7 @@
#define quantize_row_q8_K_generic quantize_row_q8_K
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
Expand Down Expand Up @@ -203,6 +207,8 @@
#elif defined(__riscv)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K
// repack.cpp
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
Expand Down Expand Up @@ -246,6 +252,7 @@
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
Expand Down Expand Up @@ -296,6 +303,7 @@
#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_stq1_0_q8_K_generic ggml_vec_dot_stq1_0_q8_K
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
Expand Down
122 changes: 122 additions & 0 deletions ggml/src/ggml-cpu/arch/arm/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -1608,6 +1608,128 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
#endif
}

// repack q8_K qs from contiguous to planar layout (in-place, per 64-byte chunk)
// after repack, vld1q_s8 reads produce the same values as vld4q_s8 de-interleave
void stq1_0_repack_q8_K_inplace(void * vy, int nb) {
block_q8_K * y = (block_q8_K *) vy;
for (int i = 0; i < nb; ++i) {
for (int j = 0; j < QK_K; j += 64) {
const int8x16x4_t v = vld4q_s8(y[i].qs + j);
vst1q_s8(y[i].qs + j + 0, v.val[0]);
vst1q_s8(y[i].qs + j + 16, v.val[1]);
vst1q_s8(y[i].qs + j + 32, v.val[2]);
vst1q_s8(y[i].qs + j + 48, v.val[3]);
}
}
}

void ggml_vec_dot_stq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);

const block_stq1_0 * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;

const int nb = n / QK_K;

#if defined(__ARM_NEON)
float sumf = 0.0f;

const uint8x16_t m3 = vdupq_n_u8(3);
const uint8x16_t mask_0f = vdupq_n_u8(0x0F);
const uint8_t (*sign_lut_16)[8] = (const uint8_t (*)[8]) table_b2b_0;

#if defined(__ARM_FEATURE_DOTPROD)
// dotprod path: single 32-byte tbl lookup + native vdotq.
const uint8x16x2_t codebook2 = { { vld1q_u8(stq1_0_codebook), vld1q_u8(stq1_0_codebook + 16) } };
#define STQ1_0_DOT(acc, sx, sy) vdotq_s32((acc), (sx), (sy))
#define STQ1_0_LOOKUP(idx) vqtbl2q_u8(codebook2, (idx))
#else
// ARMv8.0 NEON without dotprod: emulate vdotq_s32 (vmull_s8 + vpaddlq_s16)
// and split the 32-byte codebook lookup into two vqtbl1q_u8 calls. vqtbl1q_u8
// returns 0 for out-of-range indices, so OR-ing the low and high halves
// (the high half is indexed by idx-16, which underflows to >=240 for idx<16)
// reproduces vqtbl2q_u8 byte-for-byte.
const uint8x16_t cb_lo = vld1q_u8(stq1_0_codebook);
const uint8x16_t cb_hi = vld1q_u8(stq1_0_codebook + 16);
const uint8x16_t v16 = vdupq_n_u8(16);
#define STQ1_0_DOT(acc, sx, sy) ggml_vdotq_s32((acc), (sx), (sy))
#define STQ1_0_LOOKUP(idx) vorrq_u8(vqtbl1q_u8(cb_lo, (idx)), \
vqtbl1q_u8(cb_hi, vsubq_u8((idx), v16)))
#endif

// Each half processes 16 bytes of x.qs (32 codes), 4 bytes of x.sign,
// and 128 bytes of y.qs (4 lanes × 16 bytes × 2 wide-blocks).
// sumi0..sumi3, mask_0f, sign_lut_16, m3, STQ1_0_LOOKUP, STQ1_0_DOT are
// captured from the enclosing scope.
#define STQ1_0_DOT_HALF(QS_PTR, SIGN_PTR, YP_PTR) do { \
const uint8x16_t packed = vld1q_u8(QS_PTR); \
const uint8x16_t lo = vandq_u8(packed, mask_0f); \
const uint8x16_t hi = vshrq_n_u8(packed, 4); \
const uint8x16_t idx0 = vzip1q_u8(lo, hi); \
const uint8x16_t idx1 = vzip2q_u8(lo, hi); \
const uint8_t * sp = (SIGN_PTR); \
const uint8x16_t s0 = vcombine_u8(vld1_u8(sign_lut_16[sp[0]]), \
vld1_u8(sign_lut_16[sp[1]])); \
const uint8x16_t s1 = vcombine_u8(vld1_u8(sign_lut_16[sp[2]]), \
vld1_u8(sign_lut_16[sp[3]])); \
const uint8x16_t sel_0 = STQ1_0_LOOKUP(vorrq_u8(idx0, s0)); \
const uint8x16_t sel_1 = STQ1_0_LOOKUP(vorrq_u8(idx1, s1)); \
const int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(sel_0, m3)); \
const int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(sel_0, 2), m3)); \
const int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(sel_0, 4), m3)); \
const int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(sel_0, 6)); \
const int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(sel_1, m3)); \
const int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(sel_1, 2), m3)); \
const int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(sel_1, 4), m3)); \
const int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(sel_1, 6)); \
const int8_t * yp = (YP_PTR); \
sumi0 = STQ1_0_DOT(sumi0, sqx0, vld1q_s8(yp + 0)); \
sumi1 = STQ1_0_DOT(sumi1, sqx1, vld1q_s8(yp + 16)); \
sumi2 = STQ1_0_DOT(sumi2, sqx2, vld1q_s8(yp + 32)); \
sumi3 = STQ1_0_DOT(sumi3, sqx3, vld1q_s8(yp + 48)); \
sumi0 = STQ1_0_DOT(sumi0, sqx4, vld1q_s8(yp + 64)); \
sumi1 = STQ1_0_DOT(sumi1, sqx5, vld1q_s8(yp + 80)); \
sumi2 = STQ1_0_DOT(sumi2, sqx6, vld1q_s8(yp + 96)); \
sumi3 = STQ1_0_DOT(sumi3, sqx7, vld1q_s8(yp + 112)); \
} while (0)

for (int i = 0; i < nb; ++i) {
// 4 accumulators
int32x4_t sumi0 = vdupq_n_s32(0);
int32x4_t sumi1 = vdupq_n_s32(0);
int32x4_t sumi2 = vdupq_n_s32(0);
int32x4_t sumi3 = vdupq_n_s32(0);

STQ1_0_DOT_HALF(x[i].qs, x[i].sign, y[i].qs);
STQ1_0_DOT_HALF(x[i].qs + 16, x[i].sign + 4, y[i].qs + 128);

const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

sumi0 = vaddq_s32(vaddq_s32(sumi0, sumi1), vaddq_s32(sumi2, sumi3));
sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));

sumf += d * (float) vaddvq_s32(sumi0);
}

#undef STQ1_0_DOT
#undef STQ1_0_LOOKUP
#undef STQ1_0_DOT_HALF

*s = sumf;
#else
UNUSED(x);
UNUSED(y);
UNUSED(nb);
ggml_vec_dot_stq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
Expand Down
17 changes: 17 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
[GGML_TYPE_STQ1_0] = {
.from_float = quantize_row_stq1_0,
.vec_dot = ggml_vec_dot_stq1_0_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
[GGML_TYPE_I32] = {
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
},
Expand Down Expand Up @@ -1337,6 +1343,17 @@ UseGgmlGemm1:;
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
(ne10_block_end - ne10_block_start) * bs);

#if defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// STQ1_0 NEON vec_dot wants Q8_K activations in planar
// layout (vld1q reads); deinterleave once here so the
// per-row dot loop amortizes it across all M weight rows.
if (src0->type == GGML_TYPE_STQ1_0) {
stq1_0_repack_q8_K_inplace(
(void *)(wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
(int)(ne10_block_end - ne10_block_start));
}
#endif
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ void ggml_compute_forward_add(
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_STQ1_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -1130,6 +1131,7 @@ void ggml_compute_forward_add1(
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_STQ1_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -1260,6 +1262,7 @@ void ggml_compute_forward_acc(
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_STQ1_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -4395,6 +4398,7 @@ void ggml_compute_forward_out_prod(
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_STQ1_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -4672,6 +4676,7 @@ void ggml_compute_forward_set(
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_STQ1_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -4896,6 +4901,7 @@ void ggml_compute_forward_get_rows(
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_STQ1_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -5622,6 +5628,7 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_Q6_K:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_STQ1_0:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down
Loading