From 1b41d792ec202d8ebbd7b0cff7aeccb4fecc8945 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 5 Aug 2024 14:22:05 +0300 Subject: [PATCH 01/11] iq2_tn: TriLM specific 2.0625 bpw quantization Quantize/dequantize/scale dot product. I get 46 t/s for the TriLM-3.9B with any SIMD! Finally a compiler doing a decent job auto-vectorizing the scalar implementation. --- examples/quantize/quantize.cpp | 1 + ggml/include/ggml.h | 2 + ggml/src/ggml-common.h | 12 +++- ggml/src/ggml-quants.c | 1 + ggml/src/ggml.c | 21 +++++++ ggml/src/iqk/iqk_quantize.cpp | 107 +++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_quantize.h | 6 ++ include/llama.h | 1 + src/llama.cpp | 9 ++- 9 files changed, 157 insertions(+), 3 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index bae071ce..5c311e3b 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -28,6 +28,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", }, { "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", }, + { "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.06 bpw quantization (TriLM)", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 94ffae7e..144e87f5 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -393,6 +393,7 @@ extern "C" { GGML_TYPE_IQ3_K = 38, GGML_TYPE_IQ4_K = 39, GGML_TYPE_IQ5_K = 40, + GGML_TYPE_IQ2_TN = 41, GGML_TYPE_COUNT, }; @@ -443,6 +444,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_TN = 34, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 423797b6..5847d903 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -407,7 +407,7 @@ typedef struct { static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding"); // -// Bitnet - implemented as 1.75 bpw +// Bitnet - implemented as 1.625 bpw // The block scale is a waste, but it allows us to plug it in without any additional // changes to ggml. 
// @@ -418,13 +418,21 @@ typedef struct { } block_iq1_bn; static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding"); // -// Bitnet - implemented as 2.25 bpw +// Bitnet - implemented as 2.0 bpw // #define QK_IQ2BN 64 typedef struct { uint8_t qs[QK_IQ2BN/4]; } block_iq2_bn; static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding"); +// +// TriLM - implemented as 2.0625 bpw +// +typedef struct { + ggml_half d; + uint8_t qs[QK_K/4]; +} block_iq2_tn; +static_assert(sizeof(block_iq2_tn) == sizeof(ggml_half) + QK_K/4, "wrong iqt_bn block size/padding"); // Used by IQ1_M quants typedef union { diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 415249fb..9b3fddbc 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -14996,6 +14996,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ3_K: break; case GGML_TYPE_IQ4_K: break; case GGML_TYPE_IQ5_K: break; + case GGML_TYPE_IQ2_TN: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4ce9948d..5c817030 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -882,6 +882,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K64, .nrows = 1, }, + [GGML_TYPE_IQ2_TN] = { + .type_name = "iq2_tn", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_tn), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_tn, + .from_float = quantize_row_iq2_tn, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_tn_ref, + .vec_dot = vec_dot_iq2_tn_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, [GGML_TYPE_IQ4_NL] = { .type_name = "iq4_nl", .blck_size = QK4_NL, @@ -3375,6 +3387,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break; case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break; case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break; + case GGML_FTYPE_MOSTLY_IQ2_TN: wtype = GGML_TYPE_IQ2_TN; break; case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break; @@ -9628,6 +9641,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: @@ -10012,6 +10026,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: @@ -10146,6 +10161,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: @@ -13069,6 +13085,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: @@ -13263,6 +13280,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: @@ -13531,6 +13549,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + 
case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: @@ -14126,6 +14145,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: @@ -20865,6 +20885,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_TN: result = quantize_iq2_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index c840fabf..1cba1532 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1514,3 +1514,110 @@ size_t quantize_iq5_k(const float * src, void * dst, int64_t nrows, int64_t n_pe } return nrows * nblock * sizeof(block_iq5_k); } + +// +// ========================== IQ2_TN +// + +void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) { + GGML_ASSERT(k%QK_K == 0); + + int nb = k/QK_K; + + auto quantize = [] (float xmax, float x) { + return x < -0.5f*xmax ? 0 : x < 0.5f*xmax ? 
1 : 2; + }; + + for (int ibl = 0; ibl < nb; ++ibl) { + auto xb = x + QK_K*ibl; + float max = xb[0]; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(xb[j]); + max = std::max(ax, max); + } + y[ibl].d = GGML_FP32_TO_FP16(max); + auto qs = y[ibl].qs; + for (int l = 0; l < QK_K/128; ++l) { + for (int j = 0; j < 32; ++j) { + qs[j] = quantize(max, xb[j]) | (quantize(max, xb[j+32]) << 2) | (quantize(max, xb[j+64]) << 4) | (quantize(max, xb[j+96]) << 6); + } + xb += 128; + qs += 32; + } + } +} + +void quantize_row_iq2_tn(const float * x, void * y, int64_t k) { + quantize_row_iq2_tn_ref(x, (block_iq2_tn *)y, k); +} + +size_t quantize_iq2_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * /*imatrix*/) { + auto row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row); + char * qrow = (char *)dst; + for (int row = 0; row < nrows; ++row) { + quantize_row_iq2_tn_ref(src, (block_iq2_tn *)qrow, n_per_row); + qrow += row_size; + src += n_per_row; + } + return row_size*nrows; +} + +void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) { + GGML_ASSERT(k%QK_K == 0); + int nb = k/QK_K; + for (int ibl = 0; ibl < nb; ++ibl) { + float d = GGML_FP16_TO_FP32(x[ibl].d); + auto qs = x[ibl].qs; + for (int l = 0; l < QK_K/128; ++l) { + for (int j = 0; j < 32; ++j) { + y[j+ 0] = d*((qs[j] >> 0) & 3) - d; + y[j+32] = d*((qs[j] >> 2) & 3) - d; + y[j+64] = d*((qs[j] >> 4) & 3) - d; + y[j+96] = d*((qs[j] >> 6) & 3) - d; + } + y += 128; + qs += 32; + } + } +} + +void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_TN, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + const int nb = n / QK_K; + + const block_iq2_tn * x = (const block_iq2_tn *)vx; + const block_q8_K * y = (const block_q8_K *)vy; + + float sumf = 0; + + for (int i = 0; i < nb; i++) { + float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + auto qs = x[i].qs; + auto q8 = y[i].qs; + int sumi1 = 0, sumi2 = 0, sumi3 = 0,sumi4 = 0; + for (int j = 0; j < QK_K/16; ++j) sumi1 -= y[i].bsums[j]; + for (int l = 0; l < QK_K/128; ++l) { + for (int j = 0; j < 32; ++j) { + sumi1 += q8[j+ 0] * (qs[j] & 0x03); + sumi2 += q8[j+32] * (qs[j] & 0x0c); + sumi3 += q8[j+64] * (qs[j] & 0x30); + sumi4 += q8[j+96] * (qs[j] & 0xc0); + } + q8 += 128; + qs += 32; + } + sumf += d * (sumi1 + 0.25f*sumi2 + 0.0625f*sumi3 + 0.015625f*sumi4); + } + *s = sumf; +} + diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 0295eb99..80a9012b 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -37,6 +37,12 @@ size_t quantize_iq5_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, void dequantize_row_iq5_k(const block_iq5_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq5_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq2_tn_ref(const float * GGML_RESTRICT x, block_iq2_tn * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_tn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq2_tn(const block_iq2_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, 
int64_t k); +void vec_dot_iq2_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + #ifdef __cplusplus } #endif diff --git a/include/llama.h b/include/llama.h index 15ff915b..a5a2deb1 100644 --- a/include/llama.h +++ b/include/llama.h @@ -174,6 +174,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ3_K = 39, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_K = 40, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ5_K = 41, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_TN = 42, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index e530f528..7a28314e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3759,6 +3759,7 @@ struct llama_model_loader { case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break; case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break; + case GGML_TYPE_IQ2_TN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_TN; break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; @@ -4471,6 +4472,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: return "Q4_0_8_8"; case LLAMA_FTYPE_MOSTLY_IQ1_BN: return "IQ1_BN - 1.625 bpw Bitnet"; case LLAMA_FTYPE_MOSTLY_IQ2_BN: return "IQ2_BN - 2.00 bpw Bitnet"; + case LLAMA_FTYPE_MOSTLY_IQ2_TN: return "IQT_BN - 2.06 bpw TriLM"; default: return "unknown, may not work"; } @@ -15437,6 +15439,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN) { new_type = GGML_TYPE_IQ4_NL; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_TN) { + new_type = GGML_TYPE_Q4_K; + } else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) { new_type = GGML_TYPE_Q4_0; @@ -15640,7 +15645,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || - new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K) { + new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_IQ2_TN) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -15665,6 +15670,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_TN: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_IQ2_K: @@ -15773,6 +15779,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; case LLAMA_FTYPE_MOSTLY_IQ1_BN: default_type = GGML_TYPE_IQ1_BN; break; case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break; + case LLAMA_FTYPE_MOSTLY_IQ2_TN: default_type = GGML_TYPE_IQ2_TN; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = 
GGML_TYPE_IQ2_K; break; From dd0b08d1d84b99b4e9c256737f0ca43ec73ffa3b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 5 Aug 2024 14:53:49 +0300 Subject: [PATCH 02/11] iq2_tn: AVX512 Just reusing the k-quants template gets us to PP-512 = 376 t/s, TG-128 = 47.6 t/s for TriLM-3.9B. --- ggml/src/iqk/iqk_mul_mat.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 3a81d3ac..2ed3b324 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -692,6 +692,18 @@ struct DequantizerQ2K final : public BaseDequantizer { }; +struct DequantizerIQ2TN final : public BaseDequantizer { + DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + d = GGML_FP16_TO_FP32(x[i].d); + bits.prepare(x[i].qs); + process_mins_16(_mm256_set1_epi16(1), q8, i, -d, accm); + scales[0] = scales[1] = _mm512_set1_epi16(1); + } + Q2Bits bits; +}; + struct DequantizerQ3K final : public BaseDequantizer { DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template @@ -3156,6 +3168,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; + case GGML_TYPE_IQ2_TN: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; case GGML_TYPE_Q3_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); From c063954c1a945007a48c52eb9e8671c7b4f170f4 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 5 Aug 2024 15:17:27 +0300 Subject: [PATCH 03/11] iq2_tn: AVX512 With this tweak we get to PP-512 = 431 t/s. --- ggml/src/iqk/iqk_mul_mat.cpp | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 2ed3b324..37bf337b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -695,11 +695,11 @@ struct DequantizerQ2K final : public BaseDequantizer { struct DequantizerIQ2TN final : public BaseDequantizer { DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template - inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); - process_mins_16(_mm256_set1_epi16(1), q8, i, -d, accm); - scales[0] = scales[1] = _mm512_set1_epi16(1); + //process_mins_16(_mm256_set1_epi16(1), q8, i, -d, accm); + //scales[0] = scales[1] = _mm512_set1_epi16(1); } Q2Bits bits; }; @@ -997,14 +997,22 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da deq.new_block(i, q8, accm, scales); for (int iy = 0; iy < nrc_y; ++iy) { - //compute_block(iy, i, deq.d, q8, deq.bits.values, scales, accd); - const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0)); - const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1)); - const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2)); - const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3)); - auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); - sumi = 
_mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); - accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); + if constexpr (std::is_same_v) { + auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i)); + auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32( + _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0), + deq.bits.values[0], q8.load_quants64(iy, i, 0)), deq.bits.values[1], q8.load_quants64(iy, i, 1)), + deq.bits.values[2], q8.load_quants64(iy, i, 2)), deq.bits.values[3], q8.load_quants64(iy, i, 3)); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); + } else { + const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0)); + const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1)); + const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2)); + const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3)); + auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); + sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); + } } } From d0cc103878a7f2f663c26034d6e3e8b5f49a0a1b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 5 Aug 2024 15:47:12 +0300 Subject: [PATCH 04/11] iq2_tn: AVX512 With this tweak we get TG-128 = 19.58 / 35.18 t/s for 1 / 2 threads. At 4 threads we saturate at 48.41 t/s, and then performance slowly degrades with increasing number of threads. 
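For reference, the bsums trick used in this tweak (seeding the integer accumulator with -1 times the Q8_K bsums) and the scalar dot product from the first patch rely on the same identity: the stored quants are unsigned q in {0,1,2} and the actual ternary weight is d*(q-1), so sum_j d*(q_j-1)*y_j = d*(sum_j q_j*y_j - sum_j y_j), where sum_j y_j is already available from the 16 Q8_K bsums. This also lets the AVX512 path feed the unsigned quants directly to _mm512_dpbusd_epi32. A minimal scalar sketch, illustration only, assuming the 2-bit quants are already unpacked and d is the combined block scale:

#include <cstdint>

// Scalar reference for one QK_K = 256 block: the -1 offset of the ternary weights is
// folded into a single subtraction of the precomputed Q8_K bsums.
static float iq2_tn_block_dot_ref(float d, const uint8_t q[256], const int8_t y[256], const int16_t bsums[16]) {
    int sum_qy = 0;
    for (int j = 0; j < 256; ++j) sum_qy += q[j] * y[j];   // unsigned q in {0,1,2}
    int sum_y = 0;
    for (int j = 0; j < 16; ++j) sum_y += bsums[j];        // sum of all 256 y values
    return d * (sum_qy - sum_y);
}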
--- ggml/src/iqk/iqk_mul_mat.cpp | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 37bf337b..758e350e 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -698,8 +698,6 @@ struct DequantizerIQ2TN final : public BaseDequantizer { inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); - //process_mins_16(_mm256_set1_epi16(1), q8, i, -d, accm); - //scales[0] = scales[1] = _mm512_set1_epi16(1); } Q2Bits bits; }; @@ -972,6 +970,16 @@ inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } +template +inline void compute_block_iq2tn(int iy, int i, float d, const Q8& q8, const __m512i * values, __m512 * accd) { + auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i)); + auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32( + _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0), + values[0], q8.load_quants64(iy, i, 0)), values[1], q8.load_quants64(iy, i, 1)), + values[2], q8.load_quants64(iy, i, 2)), values[3], q8.load_quants64(iy, i, 3)); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); +} + template static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); @@ -1054,19 +1062,33 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx); - for (int kx = 0; kx < k_nx; ++kx) { - compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd); + if constexpr (std::is_same_v) { + for (int kx = 0; kx < k_nx; ++kx) { + compute_block_iq2tn(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, &accd); + } + } else { + for (int kx = 0; kx < k_nx; ++kx) { + compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd); + } } } if (2*(nb/2) < nb) { int i0 = 2*(nb/2); deq[0]->new_block(i0, q8, &accm, scales); - compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd); + if constexpr (std::is_same_v) { + compute_block_iq2tn(0, i0, deq[0]->d, q8, deq[0]->bits.values, &accd); + } else { + compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd); + } } - auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1)); - info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256))); + if constexpr (std::is_same_v) { + info.store(ix, 0, _mm512_reduce_add_ps(accd)); + } else { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1)); + info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256))); + } } } From a63ba11a2565e49ce30345e9518db08ed025a5b8 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 5 Aug 2024 17:25:14 +0300 Subject: [PATCH 05/11] iq2_tn: AVX2 PP512 = 440 t/s on the Ryzen-5975WX. We should be able to do better. 
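The AVX2 kernel that follows keeps the whole per-block accumulation in 16-bit lanes (_mm256_maddubs_epi16 on the unsigned 2-bit quants, with the Q8_K bsums subtracted at the end) and widens to 32 bits only once per block via a single _mm256_madd_epi16. That is safe only because the quants are at most 3; a rough worst-case bound on one 16-bit lane, as an illustration:

// Each 16-bit lane of sumi accumulates 16 products of a 2-bit quant (<= 3) with an int8
// (|q8| <= 127), coming from 8 maddubs results with 2 products each, plus one subtracted
// Q8_K bsum (a sum of 16 int8 values). This stays well inside int16_t range.
static_assert(16 * 3 * 127 + 16 * 127 <= 32767, "enough 16-bit headroom for 2-bit quants");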
--- ggml/src/iqk/iqk_mul_mat.cpp | 85 ++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 758e350e..5eea36c0 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1481,6 +1481,80 @@ struct DequantizerQ6K final : public BaseDequantizer { const __m256i mh = _mm256_set1_epi8(0x30); }; +struct DequantizerIQ2TN final : public BaseDequantizer { + DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + //template + //inline void new_block(int i, const Q8& q8, __m256i * sumi) { + // d = GGML_FP16_TO_FP32(x[i].d); + // for (int iy = 0; iy < Q8::nrc_y; ++iy) { + // sumi[iy] = q8.load_bsums(iy, i); + // } + //} + inline void new_block(int i) { + d = GGML_FP16_TO_FP32(x[i].d); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + } + + Q2Bits bits; +}; + + +template +IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Q8 q8(info); + DequantizerIQ2TN deq(vx, bx); + + __m256 accd[nrc_y]; + const auto m1 = _mm256_set1_epi16(1); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + __m256i sumi[nrc_y]; + //deq.new_block(i, q8, sumi); + deq.new_block(i); + + deq.prepare(i, 0); + for (int iy = 0; iy < nrc_y; ++iy) { + //sumi[iy] = _mm256_sub_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)), + // _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1))), sumi[iy]); + sumi[iy] = _mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)), + _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1))); + sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 2)), + _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 3))), sumi[iy]); + } + deq.prepare(i, 1); + for (int iy = 0; iy < nrc_y; ++iy) { + sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 4)), + _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 5))), sumi[iy]); + sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 6)), + _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 7))), sumi[iy]); + sumi[iy] = _mm256_sub_epi16(sumi[iy], q8.load_bsums(iy, i)); + } + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]); + } + + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + + } +} + template static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); @@ -3200,7 +3274,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { break; case GGML_TYPE_IQ2_TN: assert (ne00 % QK_K == 0); +#ifdef HAVE_FANCY_SIMD MulMat::set_functions(mm); +#else + mm.funcs[0] = mul_mat_iq2tn_q8_K<1>; + mm.funcs[1] = mul_mat_iq2tn_q8_K<2>; + mm.funcs[2] = mul_mat_iq2tn_q8_K<3>; + mm.funcs[3] = mul_mat_iq2tn_q8_K<4>; + mm.funcs[4] = mul_mat_iq2tn_q8_K<5>; + mm.funcs[5] = mul_mat_iq2tn_q8_K<6>; + //mm.funcs[6] = mul_mat_iq2tn_q8_K<7>; + //mm.funcs[7] = mul_mat_iq2tn_q8_K<8>; +#endif 
break; case GGML_TYPE_Q3_K: assert (ne00 % QK_K == 0); From 810285581cefec72ac35b0aa8aefda0beab84555 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 5 Aug 2024 17:53:02 +0200 Subject: [PATCH 06/11] iq2_tn: initial NEON version --- ggml/src/iqk/iqk_mul_mat.cpp | 66 ++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5eea36c0..ffe27db3 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1484,13 +1484,6 @@ struct DequantizerQ6K final : public BaseDequantizer { struct DequantizerIQ2TN final : public BaseDequantizer { DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} - //template - //inline void new_block(int i, const Q8& q8, __m256i * sumi) { - // d = GGML_FP16_TO_FP32(x[i].d); - // for (int iy = 0; iy < Q8::nrc_y; ++iy) { - // sumi[iy] = q8.load_bsums(iy, i); - // } - //} inline void new_block(int i) { d = GGML_FP16_TO_FP32(x[i].d); } @@ -3883,6 +3876,62 @@ struct DequantizerQ2K final : public BaseDequantizer { float d; }; +struct DequantizerIQ2TN final : public BaseDequantizer { + DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return true; } + + template + inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + process_scales(i, q8, acc); + return { vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1) }; + } + + template + inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); + + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); + } + } + + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + auto m1 = vdupq_n_s8(1); + bits.b1.val[0] = vsubq_s8(bits.b1.val[0], m1); + bits.b1.val[1] = vsubq_s8(bits.b1.val[1], m1); + bits.b1.val[2] = vsubq_s8(bits.b1.val[2], m1); + bits.b1.val[3] = vsubq_s8(bits.b1.val[3], m1); + bits.b2.val[0] = vsubq_s8(bits.b2.val[0], m1); + bits.b2.val[1] = vsubq_s8(bits.b2.val[1], m1); + bits.b2.val[2] = vsubq_s8(bits.b2.val[2], m1); + bits.b2.val[3] = vsubq_s8(bits.b2.val[3], m1); + } + + Q2bits bits; + + float d; +}; + // ============================= i-quants inline int32x4x4_t make_wider_8(const int8x16_t& scales8) { @@ -5400,6 +5449,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { case GGML_TYPE_Q2_K: MulMat::set_functions(m); break; + case GGML_TYPE_IQ2_TN: + MulMat::set_functions(m); + break; case GGML_TYPE_Q3_K: 
MulMat::set_functions(m); break; From e528505fc805fe7ac609eddaac0ed60108142e8c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 6 Aug 2024 06:19:59 +0200 Subject: [PATCH 07/11] iq2_tn: NEON For TriLM-3.9B running on the M2-Max we get PP-512 = 193.5 t/s, TG-128 = 75.5 t/s. This is in line with what we have for iq2_bn ant 3.3B Bitnet. --- ggml/src/iqk/iqk_mul_mat.cpp | 211 +++++++++++++++++++++++++---------- 1 file changed, 154 insertions(+), 57 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ffe27db3..3510cbaf 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3876,62 +3876,6 @@ struct DequantizerQ2K final : public BaseDequantizer { float d; }; -struct DequantizerIQ2TN final : public BaseDequantizer { - DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 16; } - constexpr static bool should_scale_quants() { return true; } - - template - inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) { - d = GGML_FP16_TO_FP32(x[i].d); - } - - template - inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { - process_scales(i, q8, acc); - return { vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1) }; - } - - template - inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - auto q8b_1 = q8.load_quants(iy, i, 4*j+0); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), - vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); - - auto q8b_2 = q8.load_quants(iy, i, 4*j+1); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), - vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); - - auto q8b_3 = q8.load_quants(iy, i, 4*j+2); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), - vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); - - auto q8b_4 = q8.load_quants(iy, i, 4*j+3); - sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), - vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); - } - } - - inline void prepare(int i, int j) { - bits.prepare(x[i].qs+32*j); - auto m1 = vdupq_n_s8(1); - bits.b1.val[0] = vsubq_s8(bits.b1.val[0], m1); - bits.b1.val[1] = vsubq_s8(bits.b1.val[1], m1); - bits.b1.val[2] = vsubq_s8(bits.b1.val[2], m1); - bits.b1.val[3] = vsubq_s8(bits.b1.val[3], m1); - bits.b2.val[0] = vsubq_s8(bits.b2.val[0], m1); - bits.b2.val[1] = vsubq_s8(bits.b2.val[1], m1); - bits.b2.val[2] = vsubq_s8(bits.b2.val[2], m1); - bits.b2.val[3] = vsubq_s8(bits.b2.val[3], m1); - } - - Q2bits bits; - - float d; -}; - // ============================= i-quants inline int32x4x4_t make_wider_8(const int8x16_t& scales8) { @@ -4460,6 +4404,151 @@ struct DequantizerIQ3S final : public BaseDequantizer { }; +struct DequantizerIQ2TN final : public BaseDequantizer { + DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return true; } + + //template + //inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) { + // d = GGML_FP16_TO_FP32(x[i].d); + //} + + inline void new_block(int i) { + d = GGML_FP16_TO_FP32(x[i].d); + } + + template + inline void 
compute(const Q8& q8, int i, int j, int32x4_t * sumi) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); + + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); + } + } + template + inline void compute1(const Q8& q8, int i, int j, int32x4_t * sumi) { + auto q8b_1 = q8.load_quants(0, i, 4*j+0); + sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = q8.load_quants(0, i, 4*j+1); + sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + q8b_1 = q8.load_quants(0, i, 4*j+2); + sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_1.val[1]); + + q8b_2 = q8.load_quants(0, i, 4*j+3); + sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_2.val[1]); + } + + IQK_ALWAYS_INLINE void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + auto m1 = vdupq_n_s8(1); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vsubq_s8(bits.b1.val[k], m1); + bits.b2.val[k] = vsubq_s8(bits.b2.val[k], m1); + } + } + + Q2bits bits; + + float d; +}; + +template +void mul_mat_iq2tn_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + DequantizerIQ2TN deq(vx, bx, nrc_y); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32x4_t acc[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); + + //deq.process_scales(i, q8, acc); + deq.new_block(i); + deq.prepare(i, 0); + deq.compute(q8, i, 0, sumi); + deq.prepare(i, 1); + deq.compute(q8, i, 1, sumi); + + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} +void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<1, block_q8_K> q8(info); + + DequantizerIQ2TN deq(vx, bx, 1); + + auto m1 = vdup_n_s16(-1); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + float32x4_t acc[2] = {}; + + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[2] = {}; + deq.new_block(i); + auto bsums = q8.load_bsums(0, i); + bsums.val[0] = vaddq_s32(bsums.val[0], bsums.val[1]); + sumi[0] = vmlal_s16(sumi[0], vget_low_s16 
(bsums.val[0]), m1); + sumi[1] = vmlal_s16(sumi[1], vget_high_s16(bsums.val[0]), m1); + deq.bits.prepare(deq.x[i].qs); + deq.compute1(q8, i, 0, sumi); + deq.bits.prepare(deq.x[i].qs+32); + deq.compute1(q8, i, 1, sumi); + + auto vd = vdupq_n_f32(deq.d*q8.scale(0, i)); + acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd); + acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd); + + } + + acc[0] = vaddq_f32(acc[0], acc[1]); + info.store(ix, 0, vaddvq_f32(acc[0])); + } +} + template void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -5450,7 +5539,15 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { MulMat::set_functions(m); break; case GGML_TYPE_IQ2_TN: - MulMat::set_functions(m); + //MulMat::set_functions(m); + m.funcs[0] = mul_mat_iq2tn_K_q8_K_1; + m.funcs[1] = mul_mat_iq2tn_K_q8_K_T<2>; + m.funcs[2] = mul_mat_iq2tn_K_q8_K_T<3>; + m.funcs[3] = mul_mat_iq2tn_K_q8_K_T<4>; + m.funcs[4] = mul_mat_iq2tn_K_q8_K_T<5>; + m.funcs[5] = mul_mat_iq2tn_K_q8_K_T<6>; + m.funcs[6] = mul_mat_iq2tn_K_q8_K_T<7>; + m.funcs[7] = mul_mat_iq2tn_K_q8_K_T<8>; break; case GGML_TYPE_Q3_K: MulMat::set_functions(m); From 5d02f7f4a5433168e8c5784cc46b4a3f24ccd17d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 6 Aug 2024 07:39:39 +0200 Subject: [PATCH 08/11] iq2_tn: Metal For TriLM-3.9B on a 30-core M2-Max we get PP-512 = 890 t/s, TG-128 = 98.5 t/s. --- ggml/src/ggml-metal.m | 29 +++++++- ggml/src/ggml-metal.metal | 144 +++++++++++++++++++++++++++++++++++++- 2 files changed, 170 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 48384923..d54d252c 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -88,6 +88,7 @@ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K, @@ -122,6 +123,7 @@ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32, @@ -152,6 +154,7 @@ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32, @@ -179,6 +182,7 @@ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32, @@ -206,6 +210,7 @@ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32, @@ -577,6 +582,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN, get_rows_iq2_tn, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K, get_rows_iq2_k, true); @@ -611,6 +617,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32, mul_mv_iq2_tn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32, mul_mv_iq2_k_f32, ctx->support_simdgroup_reduction); @@ -641,6 +648,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32, mul_mv_id_iq2_tn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32, mul_mv_id_iq2_k_f32, ctx->support_simdgroup_reduction); @@ -668,6 +676,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32, mul_mm_iq2_tn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32, mul_mm_iq2_k_f32, ctx->support_simdgroup_mm); @@ -695,6 +704,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32, mul_mm_id_iq2_tn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32, mul_mm_id_iq2_k_f32, ctx->support_simdgroup_mm); @@ -1728,6 +1738,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break; case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break; @@ -1904,6 +1915,12 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32].pipeline; } break; + case GGML_TYPE_IQ2_TN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32].pipeline; + } break; case GGML_TYPE_IQ4_NL: { nth0 = 4; @@ -1972,7 +1989,7 @@ static enum ggml_status ggml_metal_graph_compute( src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| - src0t == GGML_TYPE_IQ3_K) { + src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -2074,6 +2091,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break; @@ -2244,6 +2262,12 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline; } break; + case GGML_TYPE_IQ2_TN: + { + nth0 = 4; + nth1 = 16; + pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32].pipeline; + } break; case GGML_TYPE_IQ4_NL: { nth0 = 4; @@ -2323,7 +2347,7 @@ static enum ggml_status ggml_metal_graph_compute( src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| - src0t == GGML_TYPE_IQ3_K) { + src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -2384,6 +2408,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; + case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 53b2ddb8..1366905d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3330,6 +3330,129 @@ kernel void kernel_mul_mv_q2_K_f32( kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } +void kernel_mul_mv_iq2_tn_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_tn * x = (device const block_iq2_tn *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_iq2_tn) * nb / 2; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + + for (int ib = ix; ib < nb; ib += 4) { + + float sumy = 0.f; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy += yl[i+16]; + yl[i+24] = y4[i+96]; sumy += yl[i+24]; + } + + device const half * dh = &x[ib].d; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 
0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy); + + qs += step; + dh += step; + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq2_tn_f32")]] +kernel void kernel_mul_mv_iq2_tn_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_tn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, @@ -6009,7 +6132,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg } template -void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { +void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & reg) { const float d = xb->d; const float min = xb->dmin; device const uint8_t * q = (device const uint8_t *)xb->qs; @@ -6027,6 +6150,21 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg } } +template +void dequantize_iq2_tn(device const block_iq2_tn * xb, short il, thread type4x4 & reg) { + const half d = xb->d; + device const uint8_t * q = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1); + + il = (il/2)%4; + + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 
12 : 3); + const half dl = d * coef; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - d; + } +} + template void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { const half d_all = xb->d; @@ -6892,6 +7030,7 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; @@ -6926,6 +7065,7 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; @@ -6960,6 +7100,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; @@ -7175,6 +7316,7 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_tn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; From 2cc63386706a52d0b8b713c0ddee1e285f178a8b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 6 Aug 2024 10:06:39 +0300 Subject: [PATCH 09/11] iq2_tn: CUDA For TriLM-3.9B running on RTX-4080 we get PP-512 = 9936 t/s, TG-128 = 299.2 t/s. 
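On CUDA the same -1 offset is handled four quants at a time inside the mmvq kernel below: rather than using the q8_1 block sum (an equivalent variant is kept commented out), vec_dot_iq2_tn_q8_1 issues one extra dp4a against 0x01010101, which sums the four int8 values, so that dp4a(q, y) - dp4a(1, y) equals the dot product with the offset applied. A scalar sketch of what one such step computes, illustration only:

#include <cstdint>

// Reference for one dp4a pair in vec_dot_iq2_tn_q8_1:
// dp4a(v & 0x03030303, u, 0) - dp4a(0x01010101, u, 0) over four packed bytes equals
// the dot product of the four ternary weights (q - 1) with the four int8 y values.
static int dot4_ternary_ref(const uint8_t q[4], const int8_t y[4]) {
    int acc = 0;
    for (int k = 0; k < 4; ++k) acc += (int(q[k]) - 1) * int(y[k]);
    return acc;
}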
--- ggml/src/ggml-cuda.cu | 1 + ggml/src/ggml-cuda/common.cuh | 7 ++++++ ggml/src/ggml-cuda/convert.cu | 31 ++++++++++++++++++++++++ ggml/src/ggml-cuda/iqk_mmvq.cu | 42 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/iqk_mmvq.cuh | 4 ++++ ggml/src/ggml-cuda/mmvq.cu | 3 +++ 6 files changed, 88 insertions(+) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index d34aa386..a115a1b4 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2759,6 +2759,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_TN: return true; default: return false; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index fbc52aa9..c18e865a 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -655,6 +655,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI1_BN; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_K; + static constexpr int qi = QI2_K; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK4_NL; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index ed7e4bd0..47ab92f0 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -153,6 +153,27 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); } +template +static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq2_tn * x = (const block_iq2_tn *) vx; + + const int64_t tid = threadIdx.x; + const int64_t n = tid/32; + const int64_t l = tid - 32*n; + const int64_t is = 8*n + l/16; + + const uint8_t q = x[i].qs[32*n + l]; + dst_t * y = yy + i*QK_K + 128*n; + + float d = __half2float(x[i].d); + y[l+ 0] = d * ((q >> 0) & 3) - d; + y[l+32] = d * ((q >> 2) & 3) - d; + y[l+64] = d * ((q >> 4) & 3) - d; + y[l+96] = d * ((q >> 6) & 3) - d; +} + template static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -646,6 +667,12 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k dequantize_block_q2_K<<>>(vx, y); } +template +static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_tn<<>>(vx, y); +} + template static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; @@ -812,6 +839,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_block_cuda; case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; + case GGML_TYPE_IQ2_TN: + return dequantize_row_iq2_tn_cuda; case GGML_TYPE_Q3_K: return dequantize_row_q3_K_cuda; case GGML_TYPE_Q4_K: @@ -871,6 +900,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_block_cuda; case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; + case GGML_TYPE_IQ2_TN: + return dequantize_row_iq2_tn_cuda; case GGML_TYPE_Q3_K: return dequantize_row_q3_K_cuda; case GGML_TYPE_Q4_K: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index acb495d1..8def1547 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -469,6 +469,41 @@ __device__ 
__forceinline__ float vec_dot_iq3_k_q8_1( } +#define VDR_IQ2_TN_Q8_1_MMVQ 1 +#define VDR_IQ2_TN_Q8_1_MMQ 4 + +static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq2_tn * bq2 = (const block_iq2_tn *) vbq + kbx; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + + const uint16_t * q16 = (const uint16_t *)bq2->qs + 2*iqs; + int v = q16[0] | (q16[1] << 16); + + float sumf = 0; + for (int i = 0; i < QR2_K; ++ i) { + int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1); + float d8 = __low2float(bq8_1[bq8_offset + i].ds); + sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0)); + v >>= 2; + } + return __half2float(bq2->d) * sumf; + + //float sumf_d = 0; + //float sumf_m = 0; + //for (int i = 0; i < QR2_K; ++ i) { + // int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1); + // float2 d8 = __half22float2(bq8_1[bq8_offset + i].ds); + // sumf_d += d8.x * ggml_cuda_dp4a(v & 0x03030303, u, 0); + // sumf_m += d8.y; + // v >>= 2; + //} + //return __half2float(bq2->d) * (sumf_d - 0.125f * sumf_m); + +} + } // namespace void mul_mat_vec_iq2_k_q8_1_cuda( @@ -499,3 +534,10 @@ void mul_mat_vec_iq5_k_q8_1_cuda( iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +void mul_mat_vec_iq2_tn_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 9a33af0d..3dc5f41c 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -16,3 +16,7 @@ void mul_mat_vec_iq5_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); +void mul_mat_vec_iq2_tn_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 56bf3ebe..428d822f 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -426,6 +426,9 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ2_BN: mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ2_TN: + mul_mat_vec_iq2_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ4_NL: mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; From 9780ac459157c5efe9b3573b035fefcb4e56a02b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 6 Aug 2024 12:34:44 +0300 Subject: [PATCH 10/11] iq2_tn: AVX2 PP improvement We now get PP-512 = 490.73 t/s for TriLM-3.9B on the Ryzen-5975WX. We have PP-512 = 636.61 t/s for Bitnet-3B quantized with iq2_bn. Bitnet-3B is actually 3.4B, TriLM-3.9B is 3.99B, so we would expect 3.43/3.99 * 636 = 546 t/s, so it seems we still have something that is not quite optimal in iq2_tn.
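The AVX2 hunk below (and the NEON change in the next patch) drops the zero-initialized accumulators and instead uses a plain multiply on the first block, so the FMA chain no longer starts from an explicitly zeroed register. A minimal scalar sketch of the same pattern, with illustrative names (scale, sum, nb) that are not taken from the kernel, assuming nb > 0 as in the real code:

```cpp
// Before: zero the accumulator, then a fused multiply-add for every block.
float acc = 0.0f;
for (int i = 0; i < nb; ++i) acc = acc + scale[i] * sum[i];

// After: the first block is a plain multiply, later blocks keep the FMA form.
float acc2;
for (int i = 0; i < nb; ++i) {
    acc2 = i > 0 ? acc2 + scale[i] * sum[i]   // _mm256_fmadd_ps / vmlaq_f32 on later blocks
                 : scale[i] * sum[i];         // _mm256_mul_ps / vmulq_f32 on the first block
}
```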
--- ggml/src/iqk/iqk_mul_mat.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 3510cbaf..fffb3ab2 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1510,18 +1510,13 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da deq.new_row(ix); - for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { __m256i sumi[nrc_y]; - //deq.new_block(i, q8, sumi); deq.new_block(i); deq.prepare(i, 0); for (int iy = 0; iy < nrc_y; ++iy) { - //sumi[iy] = _mm256_sub_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)), - // _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1))), sumi[iy]); sumi[iy] = _mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)), _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1))); sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 2)), @@ -1535,8 +1530,14 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 7))), sumi[iy]); sumi[iy] = _mm256_sub_epi16(sumi[iy], q8.load_bsums(iy, i)); } - for (int iy = 0; iy < nrc_y; ++iy) { - accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]); + if (i > 0) { + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]); + } + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_mul_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy]))); + } } } @@ -2040,7 +2041,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16( _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)), _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4)))); - accd[iy] = _mm256_add_epi32(dot, accd[iy]); + accd[iy] = i > 0 ? _mm256_add_epi32(dot, accd[iy]) : dot; #endif } } @@ -3275,7 +3276,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[2] = mul_mat_iq2tn_q8_K<3>; mm.funcs[3] = mul_mat_iq2tn_q8_K<4>; mm.funcs[4] = mul_mat_iq2tn_q8_K<5>; - mm.funcs[5] = mul_mat_iq2tn_q8_K<6>; + //mm.funcs[5] = mul_mat_iq2tn_q8_K<6>; //mm.funcs[6] = mul_mat_iq2tn_q8_K<7>; //mm.funcs[7] = mul_mat_iq2tn_q8_K<8>; #endif From 8178075f8417822fcfe78bf8758a9cf43cc31239 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 6 Aug 2024 12:08:22 +0200 Subject: [PATCH 11/11] iq2_tn: small NEON improvement For TriLM-3.9B we now get PP-512 = 206.6 t/s and TG-128 = 76.4 t/s. 
--- ggml/src/iqk/iqk_mul_mat.cpp | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index fffb3ab2..db83b841 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -4481,28 +4481,31 @@ void mul_mat_iq2tn_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& i Q8 q8(info); DequantizerIQ2TN deq(vx, bx, nrc_y); + float32x4_t acc[nrc_y]; for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); - float32x4_t acc[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); - for (int i = 0; i < nb; ++i) { int32x4_t sumi[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); - //deq.process_scales(i, q8, acc); deq.new_block(i); deq.prepare(i, 0); deq.compute(q8, i, 0, sumi); deq.prepare(i, 1); deq.compute(q8, i, 1, sumi); - for (int iy = 0; iy < nrc_y; ++iy) { - acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + if (i > 0) { + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmulq_f32(vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } } } @@ -4520,11 +4523,11 @@ void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& i DequantizerIQ2TN deq(vx, bx, 1); auto m1 = vdup_n_s16(-1); + float32x4_t acc[2]; for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); - float32x4_t acc[2] = {}; for (int i = 0; i < nb; ++i) { @@ -4540,8 +4543,13 @@ void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& i deq.compute1(q8, i, 1, sumi); auto vd = vdupq_n_f32(deq.d*q8.scale(0, i)); - acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd); - acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd); + if (i > 0) { + acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd); + acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd); + } else { + acc[0] = vmulq_f32(vcvtq_f32_s32(sumi[0]), vd); + acc[1] = vmulq_f32(vcvtq_f32_s32(sumi[1]), vd); + } }
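A closing note on the arithmetic shared by the CUDA, AVX2 and NEON dot products in these patches: with 2-bit quants q in {0,1,2} encoding ternary weights w = q - 1, the integer dot product against int8 activations u is sum(q*u) - sum(u). That is what the dp4a pair in vec_dot_iq2_tn_q8_1 and the load_bsums subtraction in the AVX2 kernel compute; the integer result is then scaled by the block scale d and the Q8 scale. A scalar sketch of the identity (illustrative only, the name is not from the patches):

```cpp
// q[j] in {0,1,2} are the unpacked 2-bit quants of one block, u[j] the int8 activations.
static int ternary_dot_ref_sketch(const uint8_t * q, const int8_t * u, int n) {
    int sum_qu = 0, sum_u = 0;
    for (int j = 0; j < n; ++j) {
        sum_qu += q[j] * u[j];   // e.g. dp4a(v & 0x03030303, u, 0), four values at a time
        sum_u  += u[j];          // e.g. dp4a(0x01010101, u, 0), or the precomputed Q8 bsums
    }
    return sum_qu - sum_u;       // equals sum_j (q[j] - 1) * u[j]
}
```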