
Commit 3248a35

Authored by ikawrakow and Iwan Kawrakow

Adding IQ3_KS quants (#566)
* iq3_ks: basics
* iq3_ks: CUDA dequantize
* iq3_ks: CUDA mmvq
* iq3_ks: mmq
* iq3_ks: faster mmq
* iq3_ks: Zen4
* iq3_ks: AVX2 convert to q8_k_r8. This gives us PP-512 = 360 t/s.
* iq3_ks: AVX2 GEMM/GEMV
* iq3_ks: NEON GEMM/GEMV
* iq3_ks: NEON convert to q8_k_r8. This gives us PP-512 = 164 t/s.
* iq3_ks: Metal dequantize
* iq3_ks: Metal gemv - pathetic performance

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 46f2e5d commit 3248a35

File tree

22 files changed: +1040 additions, -65 deletions


examples/quantize/quantize.cpp

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_K_R4", LLAMA_FTYPE_MOSTLY_IQ2_K_R4, "IQ2_K repacked",},
     { "IQ2_KS",   LLAMA_FTYPE_MOSTLY_IQ2_KS,   " 2.1875 bpw non-linear quantization",},
     { "IQ2_KT",   LLAMA_FTYPE_MOSTLY_IQ2_KT,   " 2.125 bpw trellis quantization", },
+    { "IQ3_KS",   LLAMA_FTYPE_MOSTLY_IQ3_KS,   " 3.19 bpw non-linear quantization", },
     { "IQ3_K",    LLAMA_FTYPE_MOSTLY_IQ3_K,    " 3.44 bpw non-linear quantization", },
     { "IQ3_K_R4", LLAMA_FTYPE_MOSTLY_IQ3_K_R4, "IQ3_K repacked", },
     { "IQ3_KL",   LLAMA_FTYPE_MOSTLY_IQ3_KL,   " 4 bpw non-linear quantization mix",},

ggml/include/ggml.h

Lines changed: 2 additions & 0 deletions
@@ -429,6 +429,7 @@ extern "C" {
         GGML_TYPE_IQ2_KT = 153,
         GGML_TYPE_IQ3_KT = 154,
         GGML_TYPE_IQ4_KT = 155,
+        GGML_TYPE_IQ3_KS = 156,
 
         GGML_TYPE_Q4_0_R8 = 202,
         GGML_TYPE_Q5_0_R4 = 206,
@@ -521,6 +522,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ2_KT = 142, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ3_KT = 143, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ4_KT = 144, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ3_KS = 145, // except 1d tensors
         //
         GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors

ggml/src/ggml-common.h

Lines changed: 8 additions & 0 deletions
@@ -649,6 +649,14 @@ typedef struct {
 } block_iq3_k;
 static_assert(sizeof(block_iq3_k) == sizeof(ggml_half) + 2*sizeof(uint16_t) + QK_K/32 + QK_K/4 + QK_K/8, "wrong iq3_k block size/padding");
 
+typedef struct {
+    uint16_t extra;
+    uint8_t scales[QK_K/64];
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/8];
+} block_iq3_ks;
+static_assert(sizeof(block_iq3_ks) == sizeof(uint16_t) + QK_K/64 + QK_K/4 + QK_K/8, "wrong iq3_ks block size/padding");
+
 typedef struct {
     ggml_half d[4];
     uint8_t extra[8];
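
Note on the bit budget (not part of the diff, just arithmetic on the struct above, assuming the default QK_K = 256): the block packs a 16-bit extra field, 4 scale bytes, QK_K/4 = 64 bytes of 2-bit quants and QK_K/8 = 32 bytes of high bits, i.e. 2 + 4 + 64 + 32 = 102 bytes per 256 weights, or 3.1875 bpw. On top of that there is a single fp16 scale per row (read at the start of each row by the CUDA dequantize kernel further down), which is what brings the figure to the ~3.19 bpw quoted in quantize.cpp.

// Illustrative check only, equivalent to the static_assert added above (QK_K = 256 assumed):
static_assert(sizeof(block_iq3_ks) == 2 + 4 + 64 + 32, "block_iq3_ks is 102 bytes per 256 weights, i.e. 3.1875 bpw");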

ggml/src/ggml-cuda.cu

Lines changed: 1 addition & 0 deletions
@@ -3465,6 +3465,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_KS:
         case GGML_TYPE_IQ4_KS:
         case GGML_TYPE_IQ4_KSS:
         case GGML_TYPE_IQ5_KS:

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
@@ -599,6 +599,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
     static constexpr int qi = QI4_XS;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_KS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> {
     static constexpr int qk = QK_K;

ggml/src/ggml-cuda/convert.cu

Lines changed: 59 additions & 0 deletions
@@ -1333,6 +1333,51 @@ static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_
     }
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_ks(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+    int64_t ii  = blockIdx.x;
+    int64_t row = (QK_K * ii) / n_per_row;
+    const char * cx = (const char *)vx + row * row_size;
+    float scale = *(const ggml_half *)cx;
+    const block_iq3_ks * x = (const block_iq3_ks *)(cx + sizeof(ggml_half));
+    const int64_t i = ii - (row*n_per_row)/QK_K;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t is = tid/16;
+    const int64_t il = tid%16;
+    dst_t * y = yy + ii*QK_K + 128*is + 2*il;
+    const uint8_t * qs = x[i].qs + 32*is + 2*il;
+    const uint8_t * qh = x[i].qh + 2*il;
+    uint16_t extra = x[i].extra >> 4*is;
+    const float d0 = scale * (int(((x[i].scales[0] >> 4*is) & 0xf) | ((extra << 4) & 0x10)) - 16);
+    const float d1 = scale * (int(((x[i].scales[1] >> 4*is) & 0xf) | ((extra << 3) & 0x10)) - 16);
+    const float d2 = scale * (int(((x[i].scales[2] >> 4*is) & 0xf) | ((extra << 2) & 0x10)) - 16);
+    const float d3 = scale * (int(((x[i].scales[3] >> 4*is) & 0xf) | ((extra << 1) & 0x10)) - 16);
+    extra >>= 8;
+    const int8_t * values0 = iq3nl_values + ((extra & 1) << 3);
+    const int8_t * values1 = iq3nl_values + ((extra & 2) << 2);
+    const int8_t * values2 = iq3nl_values + ((extra & 4) << 1);
+    const int8_t * values3 = iq3nl_values + ((extra & 8) << 0);
+    if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+        for (int j = 0; j < 2; ++j) {
+            uint8_t h = qh[j] >> 4*is;
+            y[j+ 0] = __float2bfloat16(d0 * values0[((qs[j] >> 0) & 3) | ((h << 2) & 4)]);
+            y[j+32] = __float2bfloat16(d1 * values1[((qs[j] >> 2) & 3) | ((h << 1) & 4)]);
+            y[j+64] = __float2bfloat16(d2 * values2[((qs[j] >> 4) & 3) | ((h >> 0) & 4)]);
+            y[j+96] = __float2bfloat16(d3 * values3[((qs[j] >> 6) & 3) | ((h >> 1) & 4)]);
+        }
+    } else {
+        for (int j = 0; j < 2; ++j) {
+            uint8_t h = qh[j] >> 4*is;
+            y[j+ 0] = d0 * values0[((qs[j] >> 0) & 3) | ((h << 2) & 4)];
+            y[j+32] = d1 * values1[((qs[j] >> 2) & 3) | ((h << 1) & 4)];
+            y[j+64] = d2 * values2[((qs[j] >> 4) & 3) | ((h >> 0) & 4)];
+            y[j+96] = d3 * values3[((qs[j] >> 6) & 3) | ((h >> 1) & 4)];
+        }
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
     const int64_t k = nrows * n_per_row;
@@ -1573,6 +1618,14 @@ static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_iq3_k<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq3_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int64_t row_size = ggml_row_size(GGML_TYPE_IQ3_KS, n_per_row);
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_iq3_ks<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
+}
+
 template<typename dst_t>
 static void dequantize_row_iq3_k_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
     const int64_t k = nrows * n_per_row;
@@ -1719,6 +1772,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
         return dequantize_row_iq2_k_cuda<nv_bfloat16>;
     case GGML_TYPE_IQ3_K:
         return dequantize_row_iq3_k_cuda<nv_bfloat16>;
+    case GGML_TYPE_IQ3_KS:
+        return dequantize_row_iq3_ks_cuda<nv_bfloat16>;
     case GGML_TYPE_IQ4_KSS:
         return dequantize_row_iq4_kss_cuda<nv_bfloat16>;
     case GGML_TYPE_IQ4_KS:
@@ -1821,6 +1876,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         return dequantize_row_iq2_k_cuda;
     case GGML_TYPE_IQ3_K:
        return dequantize_row_iq3_k_cuda;
+    case GGML_TYPE_IQ3_KS:
+        return dequantize_row_iq3_ks_cuda;
     case GGML_TYPE_IQ4_K:
         return dequantize_row_iq4_k_cuda;
     case GGML_TYPE_IQ5_K:
@@ -1916,6 +1973,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
         return dequantize_row_iq2_k_cuda;
     case GGML_TYPE_IQ3_K:
         return dequantize_row_iq3_k_cuda;
+    case GGML_TYPE_IQ3_KS:
+        return dequantize_row_iq3_ks_cuda;
     case GGML_TYPE_IQ4_K:
         return dequantize_row_iq4_k_cuda;
     case GGML_TYPE_IQ5_K:
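
For readers who do not want to trace the packed shifts in the CUDA kernel above, here is a minimal host-side sketch of the same unpacking (not part of this commit). It assumes QK_K = 256, the block_iq3_ks layout from ggml-common.h, ggml's GGML_FP16_TO_FP32 helper, and the repo's 16-entry iq3nl_values table, which is passed in explicitly here; the _ref name is hypothetical.

// Scalar dequantization of one IQ3_KS row: a per-row fp16 scale followed by
// n_per_row/QK_K blocks. Each block holds 8 sub-blocks of 32 weights.
static void dequantize_row_iq3_ks_ref(const void * vx, float * y, int64_t n_per_row,
                                      const int8_t * iq3nl_values /* 16 entries */) {
    const char * cx = (const char *)vx;
    const float d = GGML_FP16_TO_FP32(*(const ggml_half *)cx);   // per-row scale
    const block_iq3_ks * x = (const block_iq3_ks *)(cx + sizeof(ggml_half));
    for (int64_t ib = 0; ib < n_per_row/QK_K; ++ib) {
        const uint16_t extra = x[ib].extra;
        for (int is = 0; is < 2; ++is) {          // first/second group of 128 weights
            for (int k = 0; k < 4; ++k) {         // 32-weight sub-block within the group
                const int sb = 4*is + k;          // sub-block index 0..7
                // 4-bit scale nibble plus a 5th bit from the low byte of extra, offset by -16
                const int ls = ((x[ib].scales[k] >> 4*is) & 0xf) | (((extra >> sb) & 1) << 4);
                const float dl = d * (ls - 16);
                // the high byte of extra picks one of the two halves of the non-linear value table
                const int8_t * values = iq3nl_values + 8*((extra >> (8 + sb)) & 1);
                for (int l = 0; l < 32; ++l) {
                    const int q = ((x[ib].qs[32*is + l] >> 2*k) & 3)    // 2 low bits
                                | (((x[ib].qh[l] >> sb) & 1) << 2);     // 1 high bit
                    y[QK_K*ib + 128*is + 32*k + l] = dl * values[q];
                }
            }
        }
    }
}

In other words, each of the 8 sub-blocks of 32 weights carries a 5-bit scale (4 bits in scales plus 1 bit in extra, offset by -16) and a 1-bit table selector (the high byte of extra), while each weight is 2 bits in qs plus 1 high bit in qh, used as an index into the 8-entry non-linear value table.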

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 75 additions & 0 deletions
@@ -1104,6 +1104,73 @@ __device__ __forceinline__ void vec_dot_iq3_k_q8_1(
 
 }
 
+__device__ __forceinline__ void vec_dot_iq3_ks_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs, float * result) {
+
+    float d = __half2float(*(const half *)vbq);
+    const block_iq3_ks * bq3 = (const block_iq3_ks *)((const char *)vbq + sizeof(half)) + kbx;
+
+    int iqs = iiqs/4;
+    const int ib128 = iqs/4;  // 0 or 1. 0 works on quants 0...127, 1 on quants 128...255
+    // Each thread processes 8 quants in each of the 4 32-blocks
+    const int il8 = iqs%4;    // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31
+    const int shift = 4*(il8/2);
+
+    const uint16_t * ql = (const uint16_t *)bq3->qs + 16*ib128 + 4*il8;
+    const uint16_t * qh = (const uint16_t *)bq3->qh + 4*il8;
+
+    int32_t aux32;
+    const uint8_t * aux8 = (const uint8_t *)&aux32;
+
+    uint16_t extra   = bq3->extra >> 4*ib128;
+    uint16_t extra_v = extra >> 8;
+
+    const uint16_t * values1 = iq3k_table + ((extra_v << 6) & 0x40);
+    const uint16_t * values2 = iq3k_table + ((extra_v << 5) & 0x40);
+    const uint16_t * values3 = iq3k_table + ((extra_v << 4) & 0x40);
+    const uint16_t * values4 = iq3k_table + ((extra_v << 3) & 0x40);
+
+    const int * q8;
+    int sumi[4] = {0, 0, 0, 0};
+    int v;
+    for (int i = 0; i < 2; ++i) {
+        uint32_t vl = ql[2*i+0] | (ql[2*i+1] << 16);
+        uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) >> 4*ib128) << 2;
+
+        q8 = (const int *)bq8_1[4*ib128+0].qs + 2*il8;
+        aux32 = (vl & 0x03030303) | (vh & 0x04040404);
+        v = int_from_table_2(aux8, values1);
+        sumi[0] = ggml_cuda_dp4a(v, q8[i], sumi[0]);
+        vl >>= 2; vh >>= 1;
+
+        q8 += sizeof(block_q8_1)/4;
+        aux32 = (vl & 0x03030303) | (vh & 0x04040404);
+        v = int_from_table_2(aux8, values2);
+        sumi[1] = ggml_cuda_dp4a(v, q8[i], sumi[1]);
+        vl >>= 2; vh >>= 1;
+
+        q8 += sizeof(block_q8_1)/4;
+        aux32 = (vl & 0x03030303) | (vh & 0x04040404);
+        v = int_from_table_2(aux8, values3);
+        sumi[2] = ggml_cuda_dp4a(v, q8[i], sumi[2]);
+        vl >>= 2; vh >>= 1;
+
+        q8 += sizeof(block_q8_1)/4;
+        aux32 = (vl & 0x03030303) | (vh & 0x04040404);
+        v = int_from_table_2(aux8, values4);
+        sumi[3] = ggml_cuda_dp4a(v, q8[i], sumi[3]);
+
+    }
+    const uint16_t * sl16 = (const uint16_t *)bq3->scales;
+    aux32 = __vsub4(((sl16[0] | (sl16[1] << 16)) >> 4*ib128) & 0x0f0f0f0f, 0x10101010);
+    const int8_t * a8 = (const int8_t *)&aux32;
+    *result += d * (__low2float(bq8_1[4*ib128+0].ds) * (a8[0] + ((extra << 4) & 0x10)) * sumi[0] +
+                    __low2float(bq8_1[4*ib128+1].ds) * (a8[1] + ((extra << 3) & 0x10)) * sumi[1] +
+                    __low2float(bq8_1[4*ib128+2].ds) * (a8[2] + ((extra << 2) & 0x10)) * sumi[2] +
+                    __low2float(bq8_1[4*ib128+3].ds) * (a8[3] + ((extra << 1) & 0x10)) * sumi[3]);
+
+}
+
 __device__ __forceinline__ void vec_dot_iq1_bn_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
 
@@ -1302,6 +1369,14 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
     iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
 }
 
+void mul_mat_vec_iq3_ks_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const char * ids_data,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_KS, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq3_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
 void mul_mat_vec_iq4_kt_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const char * ids_data,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
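
A side note on the packed scale recovery at the end of vec_dot_iq3_ks_q8_1 above: __vsub4 turns the four scale nibbles of the selected 128-quant half into per-byte (nibble - 16) values, and the ((extra << n) & 0x10) terms re-attach the 5th scale bit carried in the low byte of extra. A scalar equivalent, for illustration only (the helper and its name are not part of the diff):

// Per-sub-block signed scale as used in vec_dot_iq3_ks_q8_1:
// ib128 selects the 128-quant half, k the 32-quant sub-block within it.
static inline int iq3_ks_sub_block_scale(const uint8_t scales[4], uint16_t extra, int ib128, int k) {
    const int nibble = (scales[k] >> 4*ib128) & 0xf;   // 4 low scale bits
    const int high   = (extra >> (4*ib128 + k)) & 1;   // 5th bit, from the low byte of extra
    return nibble + 16*high - 16;                      // same value as a8[k] + ((extra << 4-k) & 0x10)
}

Each recovered scale then multiplies the corresponding q8_1 block scale and dp4a accumulator, with the per-row fp16 scale d applied once to the whole sum.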

ggml/src/ggml-cuda/iqk_mmvq.cuh

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,11 @@ void mul_mat_vec_iq3_k_q8_1_cuda(
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
     const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
 
+void mul_mat_vec_iq3_ks_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const char * ids_data,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
 void mul_mat_vec_iq4_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const char * ids_data,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 0 deletions
@@ -94,6 +94,9 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_IQ4_NL:
             mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ3_KS:
+            mul_mat_q_case<GGML_TYPE_IQ3_KS>(ctx, args, stream);
+            break;
         case GGML_TYPE_IQ4_KS:
             mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream);
             break;
@@ -196,6 +199,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_IQ1_S_R4:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ3_KS:
        case GGML_TYPE_IQ4_KS:
         case GGML_TYPE_IQ4_KS_R4:
         case GGML_TYPE_IQ5_KS:
