Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq2_xxs,
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_xxs_ref,
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
#ifdef __AVX2__
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_K,
#endif
.nrows = 1,
.row_meta_size = 0,
},
Expand Down
20 changes: 18 additions & 2 deletions ggml/src/iqk/iqk_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
aux32[0] = a0 & 0x3f3f3f3f;
}

#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
const uint64_t keven_signs[128] = {
0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
Expand Down Expand Up @@ -207,7 +206,6 @@ const uint64_t keven_signs[128] = {
0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
};
#endif

#ifdef __AVX2__

Expand Down Expand Up @@ -525,6 +523,24 @@ struct Q4Bits {

#endif

// In-register transpose of an 8x8 float matrix held in eight __m256 rows.
// On return, m[i] contains what was column i of the input (rows become columns).
inline void iqk_transpose_8x8(__m256 * m) {
    // Stage 1: do a 4x4 float transpose inside each 128-bit lane, separately
    // for the row groups 0..3 and 4..7.
    for (int k = 0; k < 8; k += 4) {
        auto t0 = _mm256_unpacklo_ps(m[k+0], m[k+1]);   // interleave low  pairs of rows k,k+1
        auto t1 = _mm256_unpacklo_ps(m[k+2], m[k+3]);   // interleave low  pairs of rows k+2,k+3
        auto t2 = _mm256_unpackhi_ps(m[k+0], m[k+1]);   // interleave high pairs of rows k,k+1
        auto t3 = _mm256_unpackhi_ps(m[k+2], m[k+3]);   // interleave high pairs of rows k+2,k+3
        // 64-bit unpacks (via pd casts) finish the in-lane 4x4 transpose.
        m[k+0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
        m[k+1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
        m[k+2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
        m[k+3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
    }
    // Stage 2: swap the upper 128-bit lane of row k with the lower lane of
    // row k+4 to complete the full 8x8 transpose.
    for (int k = 0; k < 4; ++k) {
        auto t = _mm256_set_m128(_mm256_extractf128_ps(m[k+4], 1), _mm256_extractf128_ps(m[k], 1));
        m[k+0] = _mm256_set_m128(_mm256_castps256_ps128(m[k+4]), _mm256_castps256_ps128(m[k+0]));
        m[k+4] = t;
    }
}

#else
// ------------------------------------ __aarch64__ --------------------------------------------------

Expand Down
238 changes: 235 additions & 3 deletions ggml/src/iqk/iqk_gemm_iquants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,12 @@ struct EvenSignHelper {
const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);
const __m256i mask = _mm256_set1_epi32(127);
const __m256i mone = _mm256_set1_epi32(1);
#else
#endif
inline void sign_value(uint32_t aux32, __m256i& value) const {
auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]);
value = _mm256_sign_epi8(value, signs);
}
#endif
};

struct SignHelper {
Expand Down Expand Up @@ -144,6 +143,35 @@ struct SignHelper {
const __m256i mone = _mm256_set1_epi8(1);
};

// for (int i = 0; i < nb; ++i) {
//
// __m256i sumi[nrc_y], all_scales;
// //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
// __m256i mins;
// float dmin = deq.new_block(i, &all_scales, mins);
// for (int iy = 0; iy < nrc_y; ++iy) {
// auto bsums = q8.load_bsums(iy, i);
// auto prod = _mm256_madd_epi16(mins, bsums);
// accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
// }
//
// for (int j = 0; j < QK_K/128; ++j) {
// deq.prepare(i, j);
// set_scales_8(&all_scales, j, scales);
// //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
// multiply_add(deq.bits, scales, j, i, q8, sumi);
// }
// for (int iy = 0; iy < nrc_y; ++iy) {
// const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
// accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
// }
// }
//
// for (int iy = 0; iy < nrc_y; ++iy) {
// info.store(ix, iy, hsum_float_8(accd[iy]));
// }
// }

struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}

Expand All @@ -165,6 +193,16 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
auto sc16 = load_scales(i);
scales[0] = MM256_SET_M128I(sc16, sc16);
}
inline void new_block_f(int i, __m256 * scales) {
auto sc16 = load_scales(i);
auto scf = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(sc16)));
auto scf_l = _mm256_castps256_ps128(scf);
auto scf_h = _mm256_extractf128_ps(scf, 1);
scales[0] = _mm256_set_m128(scf_l, scf_l);
scales[1] = _mm256_set_m128(scf_h, scf_h);
scales[2] = _mm256_mul_ps(scf, _mm256_set1_ps(-minv));
}

inline float new_block(int i, __m256i * scales, __m256i& mins) {
auto sc16 = load_scales(i);
mins = scb.shuffle(sc16);
Expand Down Expand Up @@ -729,6 +767,130 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
}
}

// GEMM kernel: multiply nrc_x rows of a k-quant/i-quant matrix (decoded by
// Dequantizer) with nrc_y rows of q8_2_x4-quantized activations.
// n     : row length in elements (must be a multiple of QK_K)
// vx/bx : left-hand matrix data and its row stride in bytes
// info  : access to the right-hand (activation) rows and the result store
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    static_assert(Dequantizer::num_blocks == 8);
    const int nb = n / QK_K;
    Q8<nrc_y, block_q8_2_x4> q8(info);
    Dequantizer deq(vx, bx);
    __m256  scales[3];      // [0],[1]: per-128 scales; [2]: min correction (see new_block_f)
    __m256  accd[nrc_y];    // float accumulators, one per activation row
    __m256i sumi[4];

    for (int ix = 0; ix < nrc_x; ++ix) {

        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();

        deq.new_row(ix);

        for (int i = 0; i < nb; ++i) {

            deq.new_block_f(i, scales);
            for (int iy = 0; iy < nrc_y; ++iy) {
                // d[4..7] of each block_q8_2_x4 holds bf16-encoded per-32 sums
                // (presumably the block sums; shifting left by 16 re-expands
                // bf16 to f32). Used to fold in the quant offset correction.
                auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
                auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
                auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
                accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
            }

            for (int j = 0; j < QK_K/128; ++j) {
                deq.prepare(i, j);
                auto& values = deq.bits.values;
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto qs = q8.y[iy][2*i+j].qs;
#ifdef HAVE_FANCY_SIMD
                    // AVX512-VNNI path: fused u8 x i8 dot products.
                    sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[0], _mm256_loadu_si256((const __m256i*)qs+0));
                    sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[1], _mm256_loadu_si256((const __m256i*)qs+1));
                    sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[2], _mm256_loadu_si256((const __m256i*)qs+2));
                    sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[3], _mm256_loadu_si256((const __m256i*)qs+3));
#else
                    // Plain AVX2 fallback: maddubs (u8 x i8 -> i16 pairs) then
                    // madd by 1 to widen the pair sums to i32.
                    sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[0], _mm256_loadu_si256((const __m256i*)qs+0)));
                    sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[1], _mm256_loadu_si256((const __m256i*)qs+1)));
                    sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[2], _mm256_loadu_si256((const __m256i*)qs+2)));
                    sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[3], _mm256_loadu_si256((const __m256i*)qs+3)));
#endif
                    // Horizontal reduction: interleaved adds collapse the four
                    // per-32-element partial sums into one vector of 8 lanes.
                    sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
                    sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
                    sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
                    // d[0..3]: bf16 activation scales, expanded to f32 by a 16-bit shift.
                    auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
                    auto dy = _mm256_set_m128(d4, d4);
                    accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
                }
            }
        }

        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, hsum_float_8(accd[iy]));
        }
    }
}

// GEMM kernel specialized for iq2_xxs x q8_2_x4: decodes the iq2_xxs grid and
// signs inline instead of going through a Dequantizer object.
// n     : row length in elements (must be a multiple of QK_K)
// vx/bx : iq2_xxs matrix data and its row stride in bytes
// info  : access to the q8_2_x4 activation rows and the result store
template <int nrc_y>
static void mul_mat_iq2_xxs_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    const int nb = n / QK_K;
    __m256  scales[2];      // per-128 block scales, duplicated across lanes
    __m256  accd[nrc_y];    // float accumulators, one per activation row
    __m256i sumi[4];
    __m256i xv[4];
    EvenSignHelper esh;

    for (int ix = 0; ix < nrc_x; ++ix) {

        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();

        const block_iq2_xxs * x = (const block_iq2_xxs *)((const char *)vx + ix*bx);

        for (int i = 0; i < nb; ++i) {
            // iq2_xxs stores its scale with an implicit 1/8 factor.
            const float d = GGML_FP16_TO_FP32(x[i].d)*0.125f;
            const uint16_t * a16 = x[i].qs;
            // Block scales live in the top 4 bits of every 4th 16-bit word;
            // decode as 2*s + 1.
            auto sc16 = _mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]);
            sc16 = _mm_or_si128(_mm_slli_epi16(_mm_srli_epi16(sc16, 12), 1), _mm_set1_epi16(1));
            auto sc32 = _mm256_cvtepi16_epi32(sc16);
            auto all_scales = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sc32));
            // The quants get a +43 offset below (to make them unsigned); this
            // is compensated here via the activation block sums.
            auto all_mins = _mm256_mul_ps(all_scales, _mm256_set1_ps(-43.f));
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto y = (const block_q8_2_x4 *)info.src1_row(iy);
                // d[4..7]: bf16-encoded per-32 sums; shift left 16 to expand bf16 -> f32.
                auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(y[2*i+0].d + 4)));
                auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(y[2*i+1].d + 4)));
                auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
                accd[iy] = _mm256_fmadd_ps(all_mins, my, accd[iy]);
            }
            auto scales_l = _mm256_castps256_ps128(all_scales);
            auto scales_h = _mm256_extractf128_ps(all_scales, 1);
            scales[0] = _mm256_set_m128(scales_l, scales_l);
            scales[1] = _mm256_set_m128(scales_h, scales_h);

            for (int j = 0; j < QK_K/128; ++j) {
                const uint8_t * a8 = (const uint8_t *)(a16 + 16*j);
                for (int k = 0; k < 4; ++k) {
                    auto a8k = a8 + 8*k;
                    // Four grid-codebook lookups per 32 quants, then the even-sign decode.
                    xv[k] = _mm256_set_epi64x(iq2xxs_grid[a8k[3]], iq2xxs_grid[a8k[2]], iq2xxs_grid[a8k[1]], iq2xxs_grid[a8k[0]]);
                    uint32_t aux32; std::memcpy(&aux32, a8k+4, sizeof(uint32_t));
                    esh.sign_value(aux32, xv[k]);
                    // Shift into the unsigned range so the u8 x i8 dot products apply.
                    xv[k] = _mm256_add_epi8(xv[k], _mm256_set1_epi8(43));
                }
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = (const block_q8_2_x4 *)info.src1_row(iy);
#ifdef HAVE_FANCY_SIMD
                    // AVX512-VNNI path: fused u8 x i8 dot products.
                    sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[0], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+0));
                    sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[1], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+1));
                    sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[2], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+2));
                    sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[3], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+3));
#else
                    // Plain AVX2 fallback, same scheme as mul_mat_qX_K_q8_2_IQ_N:
                    // maddubs (u8 x i8 -> i16 pairs) then widen with madd by 1.
                    // No i16 saturation: |xv| <= 86, so pair sums stay below 2*86*127 < 32767.
                    sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(xv[0], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+0)));
                    sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(xv[1], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+1)));
                    sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(xv[2], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+2)));
                    sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(xv[3], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+3)));
#endif
                    // Horizontal reduction of the four partial sums into 8 lanes.
                    sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
                    sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
                    sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
                    // d[0..3]: bf16 activation scales, expanded to f32 by a 16-bit shift.
                    auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)y[2*i+j].d)), 16));
                    auto dy = _mm256_set_m128(d4, d4);
                    accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
                }
            }
        }

        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, hsum_float_8(accd[iy]));
        }
    }
}

template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
Expand Down Expand Up @@ -1560,6 +1722,55 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
}
}

// Repack 8 rows of iq2_xxs data into q8_0_r8 (8-row interleaved q8_0) blocks.
// n     : row length in elements (must be a multiple of QK_K)
// vx/bx : iq2_xxs source rows and their stride in bytes
// vy    : destination q8_0_r8 blocks
// nrc_x : number of rows to convert (must be a multiple of 8)
void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;

    const block_iq2_xxs * x8[8];    // the 8 source rows being interleaved

    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;

    ggml_half dh[8];                // per-row fp16 super-block scales
    uint16_t  all_ls[64];           // per-32-block integer scales, row-interleaved
    EvenSignHelper esh;

    uint32_t block[8];
    uint32_t aux32[2];              // aux32[0]: grid indices, aux32[1]: signs + 4-bit scale
    const uint8_t * aux8 = (const uint8_t *)aux32;

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xxs *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            // TODO: simdify
            for (int k = 0; k < 8; ++k) {
                dh[k] = x8[k][i].d;
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    std::memcpy(aux32, x8[k][i].qs + 4*ib32, 2*sizeof(uint32_t));
                    // iq2_xxs scale decode: 2*s + 1 with s in the top 4 bits.
                    all_ls[8*ib32 + k] = (2*(aux32[1] >> 28) + 1);
                    // Decode 32 quants: 4 codebook lookups + even-sign application.
                    auto value = _mm256_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
                    esh.sign_value(aux32[1], value);
                    _mm256_storeu_si256((__m256i *)block, value);
                    // Scatter into the r8 interleaved layout: 4 bytes per row,
                    // 8 rows side by side; second 16 quants go 32 words later.
                    auto qs = (uint32_t *)y[ib32].qs;
                    for (int l = 0; l < 4; ++l) {
                        qs[8*l + k +  0] = block[l + 0];
                        qs[8*l + k + 32] = block[l + 4];
                    }
                }
            }
            // Combine the fp16 row scales (with the implicit iq2_xxs 1/8 factor)
            // with the integer block scales and store them as fp16.
            auto vd = _mm256_mul_ps(_mm256_set1_ps(0.125f), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh)));
            for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
                auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
                auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
                auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
                _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
            }
            y += QK_K/32;
        }
    }
}

template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
Expand All @@ -1575,7 +1786,19 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX

bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {

if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
if (ne00%QK_K != 0) return false;

if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) {
if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ2XXS, kernels);
//IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_q8_2_IQ_N, kernels);
func16 = nullptr;
return true;
}
return false;
}

if (ggml_type(typeB) != GGML_TYPE_Q8_K) {
return false;
}

Expand Down Expand Up @@ -1629,6 +1852,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_

}

// Dispatch: convert i-quant rows of the given ggml type to q8_0_r8.
// Returns false when the type is unsupported or the sizes are not
// compatible (n must be a multiple of QK_K, nrc_x a multiple of 8).
bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    const bool sizes_ok = (n % QK_K == 0) && (nrc_x % 8 == 0);
    if (!sizes_ok) return false;
    if (ggml_type(type) == GGML_TYPE_IQ2_XXS) {
        iqk_convert_iq2_xxs_q8_0_r8(n, vx, bx, vy, nrc_x);
        return true;
    }
    // No other i-quant types are handled here yet.
    return false;
}

#else
// --------------------------------------- __aarch64__ ---------------------------------------------

Expand Down
2 changes: 2 additions & 0 deletions ggml/src/iqk/iqk_gemm_iquants.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@

bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);

bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x);

#endif
Loading