diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
index 6feeff1a9..c02b3f9eb 100644
--- a/ggml/src/iqk/iqk_common.h
+++ b/ggml/src/iqk/iqk_common.h
@@ -96,6 +96,7 @@ struct DataInfo {
     int ne11;
     const mmid_row_mapping * row_mapping = nullptr;
     size_t bs2 = 0;
+    bool debug = false;
 
     inline const char * src1_row(int iy) const {
         if (!row_mapping) return cy + (cur_y + iy)*by;
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
index 0529128cc..4cc780b51 100644
--- a/ggml/src/iqk/iqk_gemm_ktquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -366,10 +366,439 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array
+namespace {
+
+template <int nrc_y>
+static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis1 trellis;
+
+    constexpr int k_acc = nrc_y == 1 ? 4 : 4*nrc_y;
+
+    float32x4_t accd[k_acc];
+    const float * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        const float d = *dptr * 31.75f * 1.05f;
+        const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
+
+        // Initialize accumulators to zero
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; ++i) {
+            const uint16_t * ql = (const uint16_t *)x[i].ql;
+            for (int ib = 0; ib < QK_K/64; ++ib) {
+                // Create scale vectors (4 elements each instead of 8)
+                float32x4_t scale1 = vdupq_n_f32(iq4k_values[x[i].scales[ib] & 0xf]);
+                float32x4_t scale2 = vdupq_n_f32(iq4k_values[x[i].scales[ib] >> 4]);
+
+                for (int j = 0; j < 4; ++j) {
+                    uint32_t val1 = ql[4*ib+j+ 0] + 4096;
+                    uint32_t val2 = ql[4*ib+j+16] + 4096;
+
+                    // Generate trellis values using 3INST
+                    const uint32x4x2_t trellis_vals1 = trellis.next8(val1);
+                    const uint32x4x2_t trellis_vals2 = trellis.next8(val2);
+                    const float32x4x2_t x_vals1 = trellis_gen8(trellis_vals1);
+                    const float32x4x2_t x_vals2 = trellis_gen8(trellis_vals2);
+
+                    // Scale the trellis values
+                    const float32x4_t x_val1_lo = vmulq_f32(scale1, x_vals1.val[0]);
+                    const float32x4_t x_val1_hi = vmulq_f32(scale1, x_vals1.val[1]);
+                    const float32x4_t x_val2_lo = vmulq_f32(scale2, x_vals2.val[0]);
+                    const float32x4_t x_val2_hi = vmulq_f32(scale2, x_vals2.val[1]);
+
+                    if constexpr (nrc_y == 1) {
+                        float32x4_t y1_lo = vld1q_f32(y[0] + i*QK_K + 32*ib + 8*j);
+                        float32x4_t y1_hi = vld1q_f32(y[0] + i*QK_K + 32*ib + 8*j + 4);
+                        float32x4_t y2_lo = vld1q_f32(y[0] + i*QK_K + 32*ib + 8*j + 128);
+                        float32x4_t y2_hi = vld1q_f32(y[0] + i*QK_K + 32*ib + 8*j + 128 + 4);
+
+                        accd[0] = vfmaq_f32(accd[0], y1_lo, x_val1_lo);
+                        accd[1] = vfmaq_f32(accd[1], y1_hi, x_val1_hi);
+                        accd[2] = vfmaq_f32(accd[2], y2_lo, x_val2_lo);
+                        accd[3] = vfmaq_f32(accd[3], y2_hi, x_val2_hi);
+                    } else {
+                        for (int iy = 0; iy < nrc_y; ++iy) {
+                            float32x4_t y1_lo = vld1q_f32(y[iy] + i*QK_K + 32*ib + 8*j);
+                            float32x4_t y1_hi = vld1q_f32(y[iy] + i*QK_K + 32*ib + 8*j + 4);
+                            float32x4_t y2_lo = vld1q_f32(y[iy] + i*QK_K + 32*ib + 8*j + 128);
+                            float32x4_t y2_hi = vld1q_f32(y[iy] + i*QK_K + 32*ib + 8*j + 128 + 4);
+
+                            accd[4*iy + 0] = vfmaq_f32(accd[4*iy + 0], y1_lo, x_val1_lo);
+                            accd[4*iy + 1] = vfmaq_f32(accd[4*iy + 1], y1_hi, x_val1_hi);
+                            accd[4*iy + 2] = vfmaq_f32(accd[4*iy + 2], y2_lo, x_val2_lo);
+                            accd[4*iy + 3] = vfmaq_f32(accd[4*iy + 3], y2_hi, x_val2_hi);
+                        }
+                    }
+                }
+            }
+        }
+
+        if constexpr (nrc_y == 1) {
+            float32x4_t sum_lo = vaddq_f32(accd[0], accd[2]);
+            float32x4_t sum_hi = vaddq_f32(accd[1], accd[3]);
+            float32x4_t res_lo = vmulq_n_f32(sum_lo, d);
+            float32x4_t res_hi = vmulq_n_f32(sum_hi, d);
+            info.store(ix, 0, hsum_float_8(res_lo, res_hi));
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                float32x4_t sum_lo = vaddq_f32(accd[4*iy + 0], accd[4*iy + 2]);
+                float32x4_t sum_hi = vaddq_f32(accd[4*iy + 1], accd[4*iy + 3]);
+                float32x4_t res_lo = vmulq_n_f32(sum_lo, d);
+                float32x4_t res_hi = vmulq_n_f32(sum_hi, d);
+                info.store(ix, iy, hsum_float_8(res_lo, res_hi));
+            }
+        }
+    }
+}
+
+static inline float32x4_t abs_ps(float32x4_t vals) {
+    // Clear sign-bit of all the 32-bit floats in vals
+    uint32x4_t sign_mask = vdupq_n_u32(0x80000000);
+    return vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vals), sign_mask));
+}
+
+// Helper to conditionally negate 4 floats based on 4 bits from condition_mask
+static inline float32x4_t conditional_negate_ps(float32x4_t vals, uint32_t condition_bits) {
+    // Create masks for each lane based on individual bits
+    uint32_t masks[4] = {
+        (condition_bits & 0x00000001) ? 0x80000000 : 0,
+        (condition_bits & 0x00000100) ? 0x80000000 : 0,
+        (condition_bits & 0x00010000) ? 0x80000000 : 0,
+        (condition_bits & 0x01000000) ? 0x80000000 : 0
+    };
+
+    uint32x4_t xor_mask = vld1q_u32(masks);
+    return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(vals), xor_mask));
+}
+
+template <int nrc_y>
+static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis1 trellis;
+
+    // Need 2 accumulators per y row for ARM Neon (4 floats each)
+    float32x4_t accd[nrc_y * 2];
+    const float * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        const float d = *dptr * 31.75f * 1.015f;
+        const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+        // Initialize accumulators to zero
+        for (int iy = 0; iy < nrc_y * 2; ++iy) accd[iy] = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; ++i) {
+            const uint16_t * ql = (const uint16_t *)x[i].ql;
+            const uint8_t * qh = x[i].qh;
+
+            for (int j = 0; j < 128; j+=8) {
+                uint64_t mask1 = 0x0101010101010101ULL << (j/32);
+                uint64_t mask2 = mask1 << 4;
+                uint32_t val1 = ql[j/8] + 4096;
+                uint32_t val2 = ql[j/8+16] + 4096;
+                const uint64_t signs = *((const uint64_t *)(qh + (j%32)));
+                const float x_scale1 = (x[i].scales[j/32] & 0xf);
+                const float x_scale2 = (x[i].scales[j/32] >> 4);
+                // Generate abs(trellis values) with 3INST
+                const uint32x4x2_t trellis_vals1 = trellis.next8(val1);
+                const uint32x4x2_t trellis_vals2 = trellis.next8(val2);
+                const float32x4x2_t gen_vals1 = trellis_gen8(trellis_vals1);
+                const float32x4x2_t gen_vals2 = trellis_gen8(trellis_vals2);
+                const float32x4_t x_val1_lo = abs_ps(gen_vals1.val[0]);
+                const float32x4_t x_val1_hi = abs_ps(gen_vals1.val[1]);
+                const float32x4_t x_val2_lo = abs_ps(gen_vals2.val[0]);
+                const float32x4_t x_val2_hi = abs_ps(gen_vals2.val[1]);
+                // Scale the values
+                const float32x4_t scale1 = vdupq_n_f32(x_scale1);
+                const float32x4_t scale2 = vdupq_n_f32(x_scale2);
+                const float32x4_t scaled_x1_lo = vmulq_f32(scale1, x_val1_lo);
+                const float32x4_t scaled_x1_hi = vmulq_f32(scale1, x_val1_hi);
+                const float32x4_t scaled_x2_lo = vmulq_f32(scale2, x_val2_lo);
+                const float32x4_t scaled_x2_hi = vmulq_f32(scale2, x_val2_hi);
+                // Extract sign bits
+                uint64_t signs_mask1 = signs & mask1;
+                uint64_t signs_mask2 = signs & mask2;
+                uint64_t sign_bits1 = signs_mask1 >> (j/32);
+                uint64_t sign_bits2 = signs_mask2 >> (j/32+4);
+
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    float32x4_t y1_lo = vld1q_f32(y[iy] + i*QK_K + j);
+                    float32x4_t y1_hi = vld1q_f32(y[iy] + i*QK_K + j + 4);
+                    float32x4_t y2_lo = vld1q_f32(y[iy] + i*QK_K + j + 128);
+                    float32x4_t y2_hi = vld1q_f32(y[iy] + i*QK_K + j + 128 + 4);
+
+                    y1_lo = conditional_negate_ps(y1_lo, sign_bits1 & 0xFFFFFFFF);
+                    y1_hi = conditional_negate_ps(y1_hi, sign_bits1 >> 32);
+                    y2_lo = conditional_negate_ps(y2_lo, sign_bits2 & 0xFFFFFFFF);
+                    y2_hi = conditional_negate_ps(y2_hi, sign_bits2 >> 32);
+
+                    accd[iy*2 + 0] = vfmaq_f32(accd[iy*2 + 0], y1_lo, scaled_x1_lo);
+                    accd[iy*2 + 1] = vfmaq_f32(accd[iy*2 + 1], y1_hi, scaled_x1_hi);
+                    accd[iy*2 + 0] = vfmaq_f32(accd[iy*2 + 0], y2_lo, scaled_x2_lo);
+                    accd[iy*2 + 1] = vfmaq_f32(accd[iy*2 + 1], y2_hi, scaled_x2_hi);
+                }
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            // Sum the two accumulators for this y row
+            float32x4_t sum = vaddq_f32(accd[iy*2], accd[iy*2 + 1]);
+            float32x4_t res = vmulq_n_f32(sum, d);
+            info.store(ix, iy, hsum_float_4(res));
+        }
+    }
+}
+
+template <int nrc_y>
+static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+    constexpr int kNumGroups = 64;
+
+    Trellis2 trellis;
+
+    float32x4_t accd[nrc_y * 2];
+    float32x4_t accd2[nrc_y * 2];
+    const float * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        const float d = dptr[0] * 31.75f * 1.01f;
+        const float row_av = dptr[1];
+        const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+
+        for (int iy = 0; iy < nrc_y * 2; ++iy) {
+            accd[iy] = vdupq_n_f32(0.0f);
+            accd2[iy] = vdupq_n_f32(0.0f);
+        }
+
+        for (int i = 0; i < nb; ++i) {
+            const uint32_t * shb = x[i].qs;
+            const uint8_t * ql = (const uint8_t *)(shb + 8);
+            const uint8_t * qh = ql + kNumGroups;
+
+            for (int j = 0; j < 128; j+=8) {
+                const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15);
+                const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15);
+                const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64;
+                const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64;
+                const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
+                const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
+
+                uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
+                uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
+                uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
+                uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
+
+                // Generate trellis values
+                const uint32x4x2_t trellis_vals1 = trellis.next8(val1, val3);
+                const uint32x4x2_t trellis_vals2 = trellis.next8(val2, val4);
+                const float32x4x2_t x_vals1 = trellis_gen8(trellis_vals1);
+                const float32x4x2_t x_vals2 = trellis_gen8(trellis_vals2);
+
+                // Scale the values
+                const float32x4_t scale1 = vdupq_n_f32(x_scale1);
+                const float32x4_t scale2 = vdupq_n_f32(x_scale2);
+                const float32x4_t x_val1_lo = vmulq_f32(scale1, x_vals1.val[0]);
+                const float32x4_t x_val1_hi = vmulq_f32(scale1, x_vals1.val[1]);
+                const float32x4_t x_val2_lo = vmulq_f32(scale2, x_vals2.val[0]);
+                const float32x4_t x_val2_hi = vmulq_f32(scale2, x_vals2.val[1]);
+
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    float32x4_t y1_lo = vld1q_f32(y[iy] + i*QK_K + j);
+                    float32x4_t y1_hi = vld1q_f32(y[iy] + i*QK_K + j + 4);
+                    float32x4_t y2_lo = vld1q_f32(y[iy] + i*QK_K + j + 128);
+                    float32x4_t y2_hi = vld1q_f32(y[iy] + i*QK_K + j + 128 + 4);
+
+                    accd[iy*2 + 0] = vfmaq_f32(accd[iy*2 + 0], y1_lo, x_val1_lo);
+                    accd[iy*2 + 1] = vfmaq_f32(accd[iy*2 + 1], y1_hi, x_val1_hi);
+                    accd[iy*2 + 0] = vfmaq_f32(accd[iy*2 + 0], y2_lo, x_val2_lo);
+                    accd[iy*2 + 1] = vfmaq_f32(accd[iy*2 + 1], y2_hi, x_val2_hi);
+
+                    accd2[iy*2 + 0] = vaddq_f32(accd2[iy*2 + 0], y1_lo);
+                    accd2[iy*2 + 1] = vaddq_f32(accd2[iy*2 + 1], y1_hi);
+                    accd2[iy*2 + 0] = vaddq_f32(accd2[iy*2 + 0], y2_lo);
+                    accd2[iy*2 + 1] = vaddq_f32(accd2[iy*2 + 1], y2_hi);
+                }
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            // Sum the two accumulators for this y row
+            float32x4_t sum1 = vaddq_f32(accd[iy*2], accd[iy*2 + 1]);
+            float32x4_t sum2 = vaddq_f32(accd2[iy*2], accd2[iy*2 + 1]);
+
+            // Scale by d and row_av
+            float32x4_t res1 = vmulq_n_f32(sum1, d);
+            float32x4_t res2 = vmulq_n_f32(sum2, row_av);
+
+            // Compute final result
+            float result = hsum_float_4(res1) + hsum_float_4(res2);
+            info.store(ix, iy, result);
+        }
+    }
+}
+
+} // namespace

 bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
-    return false;
+    if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F32) {
+        return false;
+    }
+
+    func16 = nullptr;
+
+    switch (typeA) {
+        case GGML_TYPE_IQ2_KT:
+            assert (ne00 % QK_K == 0);
+            kernels[0] = mul_mat_iq2_kt_F32_T<1>;
+            kernels[1] = mul_mat_iq2_kt_F32_T<2>;
+            kernels[2] = mul_mat_iq2_kt_F32_T<3>;
+            kernels[3] = mul_mat_iq2_kt_F32_T<4>;
+            kernels[4] = mul_mat_iq2_kt_F32_T<5>;
+            kernels[5] = mul_mat_iq2_kt_F32_T<6>;
+            kernels[6] = mul_mat_iq2_kt_F32_T<7>;
+            kernels[7] = mul_mat_iq2_kt_F32_T<8>;
+            break;
+        case GGML_TYPE_IQ3_KT:
+            assert (ne00 % QK_K == 0);
+            kernels[0] = mul_mat_iq3_kt_F32_T<1>;
+            kernels[1] = mul_mat_iq3_kt_F32_T<2>;
+            kernels[2] = mul_mat_iq3_kt_F32_T<3>;
+            kernels[3] = mul_mat_iq3_kt_F32_T<4>;
+            kernels[4] = mul_mat_iq3_kt_F32_T<5>;
+            kernels[5] = mul_mat_iq3_kt_F32_T<6>;
+            kernels[6] = mul_mat_iq3_kt_F32_T<7>;
+            kernels[7] = mul_mat_iq3_kt_F32_T<8>;
+            break;
+        case GGML_TYPE_IQ4_KT:
+            assert (ne00 % QK_K == 0);
+            kernels[0] = mul_mat_iq4_kt_F32_T<1>;
+            kernels[1] = mul_mat_iq4_kt_F32_T<2>;
+            kernels[2] = mul_mat_iq4_kt_F32_T<3>;
+            kernels[3] = mul_mat_iq4_kt_F32_T<4>;
+            kernels[4] = mul_mat_iq4_kt_F32_T<5>;
+            kernels[5] = mul_mat_iq4_kt_F32_T<6>;
+            kernels[6] = mul_mat_iq4_kt_F32_T<7>;
+            kernels[7] = mul_mat_iq4_kt_F32_T<8>;
+            break;
+        default:
+            return false;
+    }
+
+    return true;
 }

 #endif
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 43be08857..6e7038d57 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -310,7 +310,6 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
                                     int typeA, const void * A, long strideA,
                                     int typeB, const void * B, long strideB,
                                     float * C, long stride_C, int ith, int nth) {
-
     MulMat mm;
     if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
         return false;
@@ -620,6 +619,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_IQ4_KS_R4:
         case GGML_TYPE_IQ5_KS_R4:
             return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, m.funcs, m.func16);
+        case GGML_TYPE_IQ2_KT:
+        case GGML_TYPE_IQ3_KT:
+        case GGML_TYPE_IQ4_KT:
+            return iqk_set_kernels_ktquants(ne00, typeA, typeB, m.funcs, m.func16);
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ2_S:
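
Note on the sign handling in mul_mat_iq3_kt_F32_T above: the packed 64-bit sign word hands conditional_negate_ps() four sign bits per NEON vector, at bit positions 0, 8, 16 and 24, and the helper applies them by XOR-ing 0x80000000 into the selected lanes. The scalar sketch below is not part of the patch (the function name negate_by_packed_bits is made up for illustration); it only restates that bit-to-lane mapping and may be useful as a reference when checking the NEON path.

#include <cstdint>
#include <cstdio>

// Scalar reference for the mapping used by conditional_negate_ps():
// bits 0, 8, 16 and 24 of the packed word decide whether lanes 0..3 get negated.
static void negate_by_packed_bits(float vals[4], uint32_t packed_bits) {
    for (int lane = 0; lane < 4; ++lane) {
        if (packed_bits & (1u << (8*lane))) vals[lane] = -vals[lane];
    }
}

int main() {
    float v[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    negate_by_packed_bits(v, 0x00010001u);              // flip lanes 0 and 2
    printf("%g %g %g %g\n", v[0], v[1], v[2], v[3]);    // prints: -1 2 -3 4
    return 0;
}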