Skip to content

Commit

Permalink
Adding IQ2_TN for use with ternary models (#13)
Browse files Browse the repository at this point in the history
* iq2_tn: TriLM specific 2.0625 bpw quantization

Quantize/dequantize/scale dot product.

I get 46 t/s for the TriLM-3.9B with any SIMD!
Finally a compiler doing a decent job auto-vectorizing the
scalar implementation.

* iq2_tn: AVX512

Just reusing the k-quants template gets us to PP-512 = 376 t/s,
TG-128 = 47.6 t/s for TriLM-3.9B.

* iq2_tn: AVX512

With this tweak we get to PP-512 = 431 t/s.

* iq2_tn: AVX512

With this tweak we get TG-128 = 19.58 / 35.18 t/s for 1 / 2 threads.
At 4 threads we saturate at 48.41 t/s, and then performance slowly
degrades with increasing number of threads.

* iq2_tn: AVX2

PP512 = 440 t/s on the Ryzen-5975WX.
We should be able to do better.

* iq2_tn: initial NEON version

* iq2_tn: NEON

For TriLM-3.9B running on the M2-Max we get PP-512 = 193.5 t/s,
TG-128 = 75.5 t/s. This is in line with what we have for
iq2_bn ant 3.3B Bitnet.

* iq2_tn: Metal

For TriLM-3.9B on a 30-core M2-Max we get PP-512 = 890 t/s,
TG-128 = 98.5 t/s.

* iq2_tn: CUDA

For TriLM-3.9B running on RTX-4080 we get PP-512 = 9936 t/s,
TG-128 = 299.2 t/s.

* iq2_tn: AVX2 PP improvement

We now get PP-512 = 490.73 t/s for TriLM-3.9B on the Ryzen-5975WX.
We have PP-512 = 636.61 t/s for Bintnet-3B quantized with iq2_bn.
Bintnet-3B is actually 3.4B, TriLM-3.9B is 3.99B, so we would
expect 3.43/3.99 * 636 = 546 t/s, so it seems we still have something
that is not quite optimal in iq2_tn.

* iq2_tn: small NEON improvement

For TriLM-3.9B we now get PP-512 = 206.6 t/s and TG-128 = 76.4 t/s.

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
  • Loading branch information
ikawrakow and Kawrakow committed Aug 7, 2024
1 parent b409c15 commit a9f302e
Show file tree
Hide file tree
Showing 18 changed files with 718 additions and 20 deletions.
1 change: 1 addition & 0 deletions examples/quantize/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", },
{ "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", },
{ "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.06 bpw quantization (TriLM)", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
Expand Down
2 changes: 2 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ extern "C" {
GGML_TYPE_IQ3_K = 38,
GGML_TYPE_IQ4_K = 39,
GGML_TYPE_IQ5_K = 40,
GGML_TYPE_IQ2_TN = 41,
GGML_TYPE_COUNT,
};

Expand Down Expand Up @@ -443,6 +444,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_TN = 34, // except 1d tensors
};

// available tensor operations:
Expand Down
12 changes: 10 additions & 2 deletions ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ typedef struct {
static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");

//
// Bitnet - implemented as 1.75 bpw
// Bitnet - implemented as 1.625 bpw
// The block scale is a waste, but it allows us to plug it in without any additional
// changes to ggml.
//
Expand All @@ -418,13 +418,21 @@ typedef struct {
} block_iq1_bn;
static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding");
//
// Bitnet - implemented as 2.25 bpw
// Bitnet - implemented as 2.0 bpw
//
#define QK_IQ2BN 64
typedef struct {
uint8_t qs[QK_IQ2BN/4];
} block_iq2_bn;
static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding");
//
// TriLM - implemented as 2.0625 bpw
//
typedef struct {
ggml_half d;
uint8_t qs[QK_K/4];
} block_iq2_tn;
static_assert(sizeof(block_iq2_tn) == sizeof(ggml_half) + QK_K/4, "wrong iqt_bn block size/padding");

// Used by IQ1_M quants
typedef union {
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2759,6 +2759,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_TN:
return true;
default:
return false;
Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
static constexpr int qi = QI1_BN;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_TN> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_K;
static constexpr int qi = QI2_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
static constexpr int qk = QK4_NL;
Expand Down
31 changes: 31 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,27 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
}

template<typename dst_t>
static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy) {

const int64_t i = blockIdx.x;
const block_iq2_tn * x = (const block_iq2_tn *) vx;

const int64_t tid = threadIdx.x;
const int64_t n = tid/32;
const int64_t l = tid - 32*n;
const int64_t is = 8*n + l/16;

const uint8_t q = x[i].qs[32*n + l];
dst_t * y = yy + i*QK_K + 128*n;

float d = __half2float(x[i].d);
y[l+ 0] = d * ((q >> 0) & 3) - d;
y[l+32] = d * ((q >> 2) & 3) - d;
y[l+64] = d * ((q >> 4) & 3) - d;
y[l+96] = d * ((q >> 6) & 3) - d;
}

template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

Expand Down Expand Up @@ -646,6 +667,12 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
Expand Down Expand Up @@ -812,6 +839,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
case GGML_TYPE_IQ2_TN:
return dequantize_row_iq2_tn_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
Expand Down Expand Up @@ -871,6 +900,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
case GGML_TYPE_IQ2_TN:
return dequantize_row_iq2_tn_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
Expand Down
42 changes: 42 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,41 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1(

}

#define VDR_IQ2_TN_Q8_1_MMVQ 1
#define VDR_IQ2_TN_Q8_1_MMQ 4

static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

const block_iq2_tn * bq2 = (const block_iq2_tn *) vbq + kbx;

const int bq8_offset = QR2_K * (iqs / QI8_1);

const uint16_t * q16 = (const uint16_t *)bq2->qs + 2*iqs;
int v = q16[0] | (q16[1] << 16);

float sumf = 0;
for (int i = 0; i < QR2_K; ++ i) {
int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
float d8 = __low2float(bq8_1[bq8_offset + i].ds);
sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0));
v >>= 2;
}
return __half2float(bq2->d) * sumf;

//float sumf_d = 0;
//float sumf_m = 0;
//for (int i = 0; i < QR2_K; ++ i) {
// int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
// float2 d8 = __half22float2(bq8_1[bq8_offset + i].ds);
// sumf_d += d8.x * ggml_cuda_dp4a(v & 0x03030303, u, 0);
// sumf_m += d8.y;
// v >>= 2;
//}
//return __half2float(bq2->d) * (sumf_d - 0.125f * sumf_m);

}

} // namespace

void mul_mat_vec_iq2_k_q8_1_cuda(
Expand Down Expand Up @@ -499,3 +534,10 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

void mul_mat_vec_iq2_tn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

void mul_mat_vec_iq2_tn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ2_BN:
mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_TN:
mul_mat_vec_iq2_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
Expand Down
Loading

0 comments on commit a9f302e

Please sign in to comment.