Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions ggml/src/ggml-cpu/arch/wasm/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,78 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
*s = sumf;
}

// Dot product of a row of block_q4_1 values (vx) with a row of block_q8_1 values (vy).
//
//   n          number of elements; must be a multiple of QK8_1
//   s          receives the scalar result
//   vx, vy     the two quantized rows
//   bs, bx, by strides; not used by the SIMD path, forwarded to the generic fallback
//   nrc        must be 1
//
// With __wasm_simd128__ defined the work is done with 128-bit WASM SIMD intrinsics,
// otherwise the call is forwarded unchanged to ggml_vec_dot_q4_1_q8_1_generic().
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

    float sumf = 0;

#if defined __wasm_simd128__
    // Selects the low 4 bits of every byte; identical for every block.
    const v128_t nibble_mask = wasm_i8x16_splat(0x0F);

    v128_t acc_f32    = wasm_f32x4_splat(0.0f); // four partial sums of d_x * d_y * (integer dot)
    float  offset_sum = 0.0f;                   // running sum of the per-block m_x * s_y terms

    for (int ib = 0; ib < nb; ++ib) {
        const block_q4_1 * GGML_RESTRICT xb = x + ib;
        const block_q8_1 * GGML_RESTRICT yb = y + ib;

        offset_sum += GGML_CPU_FP16_TO_FP32(xb->m) * GGML_CPU_FP16_TO_FP32(yb->s);

        // One 16-byte load holds the packed 4-bit values: two per byte.
        const v128_t packed = wasm_v128_load(xb->qs);
        const v128_t q_lo   = wasm_v128_and(packed, nibble_mask);
        const v128_t q_hi   = wasm_u8x16_shr(packed, 4);

        // Low nibbles are paired with yb->qs[0..15], high nibbles with yb->qs[16..31].
        const v128_t y_first  = wasm_v128_load(yb->qs);
        const v128_t y_second = wasm_v128_load(yb->qs + 16);

        // Widen to 16 bit (x unsigned, y signed) so the i16x8 dot product can be used;
        // each dot yields four 32-bit sums of adjacent products.
        const v128_t dot_lo = wasm_i32x4_add(
            wasm_i32x4_dot_i16x8(wasm_u16x8_extend_low_u8x16(q_lo),  wasm_i16x8_extend_low_i8x16(y_first)),
            wasm_i32x4_dot_i16x8(wasm_u16x8_extend_high_u8x16(q_lo), wasm_i16x8_extend_high_i8x16(y_first)));
        const v128_t dot_hi = wasm_i32x4_add(
            wasm_i32x4_dot_i16x8(wasm_u16x8_extend_low_u8x16(q_hi),  wasm_i16x8_extend_low_i8x16(y_second)),
            wasm_i32x4_dot_i16x8(wasm_u16x8_extend_high_u8x16(q_hi), wasm_i16x8_extend_high_i8x16(y_second)));
        const v128_t dot_all = wasm_i32x4_add(dot_lo, dot_hi);

        // Scale the integer partial sums by the product of both block scales.
        const float scale = GGML_CPU_FP16_TO_FP32(xb->d) * GGML_CPU_FP16_TO_FP32(yb->d);
        acc_f32 = wasm_f32x4_add(
            acc_f32,
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dot_all), wasm_f32x4_splat(scale)));
    }

    // Horizontal reduction, lane 0 to lane 3, followed by the offset term.
    sumf  = wasm_f32x4_extract_lane(acc_f32, 0);
    sumf += wasm_f32x4_extract_lane(acc_f32, 1);
    sumf += wasm_f32x4_extract_lane(acc_f32, 2);
    sumf += wasm_f32x4_extract_lane(acc_f32, 3);
    sumf += offset_sum;

    *s = sumf;

#else
    // No SIMD available: silence unused-variable warnings and defer to the generic kernel.
    UNUSED(nb);
    UNUSED(x);
    UNUSED(y);
    UNUSED(sumf);

    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
Expand Down