50 changes: 34 additions & 16 deletions examples/quantize-stats/quantize-stats.cpp
@@ -256,6 +256,8 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo
float mse0 = 0, mse = 0;
auto compute = [&mutex, &counter, &mse0, &mse, values, row_size, nblock, nrows, n_per_row, chunk] () {
std::vector<char> Q(row_size);
float diff[4];
float xv[4];
float lmse0 = 0, lmse = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
@@ -282,25 +284,41 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo
for (int j = 0; j < 16; j += 2) {
uint16_t v0 = *(const uint16_t *)(qs + j);
int non = popcount(v0);
float diff1 = xb[j+ 0] - dl*values[qs[j+0] & 0xf];
float diff2 = xb[j+16] - dl*values[qs[j+0] >> 4];
float diff3 = xb[j+ 1] - dl*values[qs[j+1] & 0xf];
float diff4 = xb[j+17] - dl*values[qs[j+1] >> 4];
lmse0 += diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
xv[0] = xb[j+ 0]; xv[1] = xb[j+16]; xv[2] = xb[j+ 1]; xv[3] = xb[j+17];
diff[0] = xv[0] - dl*values[qs[j+0] & 0xf];
diff[1] = xv[1] - dl*values[qs[j+0] >> 4];
diff[2] = xv[2] - dl*values[qs[j+1] & 0xf];
diff[3] = xv[3] - dl*values[qs[j+1] >> 4];
float diff4 = diff[0]*diff[0] + diff[1]*diff[1] + diff[2]*diff[2] + diff[3]*diff[3];
lmse0 += diff4;
if (non%2 == 0) {
lmse += diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
lmse += diff4;
} else {
float best = std::numeric_limits<float>::max();
for (int k = 0; k < 16; k += 4) {
uint16_t v = v0 ^ (1 << k);
uint8_t v1 = v;
uint8_t v2 = v >> 8;
diff1 = xb[j+ 0] - dl*values[v1 & 0xf];
diff2 = xb[j+16] - dl*values[v1 >> 4];
diff3 = xb[j+ 1] - dl*values[v2 & 0xf];
diff4 = xb[j+17] - dl*values[v2 >> 4];
float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
if (score < best) best = score;
//for (int k = 0; k < 16; k += 4) {
// uint16_t v = v0 ^ (1 << k);
// uint8_t v1 = v;
// uint8_t v2 = v >> 8;
// diff1 = xb[j+ 0] - dl*values[v1 & 0xf];
// diff2 = xb[j+16] - dl*values[v1 >> 4];
// diff3 = xb[j+ 1] - dl*values[v2 & 0xf];
// diff4 = xb[j+17] - dl*values[v2 >> 4];
// float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
// if (score < best) best = score;
//}
for (int k = 0; k < 4; ++k) {
uint16_t v = (v0 >> 4*k) & 0xf;
auto pc = popcount(v);
if (v > 0 && popcount(v-1u) != pc) {
float this_diff = xv[k] - dl*values[v-1u];
float score = diff4 - diff[k]*diff[k] + this_diff*this_diff;
if (score < best) best = score;
}
if (v < 15 && popcount(v + 1u) != pc) {
float this_diff = xv[k] - dl*values[v+1u];
float score = diff4 - diff[k]*diff[k] + this_diff*this_diff;
if (score < best) best = score;
}
}
lmse += best;
}
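The hunk above trades the brute-force bit-flip search (kept as the commented-out block) for a cheaper neighbour search: each 16-bit group packs four 4-bit codes, and IQ4_KSS requires the group popcount to be even, so when it comes out odd exactly one code must be nudged to v-1 or v+1, and the nudge with the smallest extra squared error wins. Below is a minimal stand-alone sketch of that search, assuming a 16-entry non-linear codebook such as iq4k_values; unlike the diff, which compares raw popcounts, it tests the popcount parity explicitly:

#include <bit>       // std::popcount (C++20)
#include <cstdint>
#include <limits>

// v0     : 16-bit group holding four 4-bit codes (odd total popcount)
// xv[k]  : the four weights; diff[k]: their current errors; sum_sq: sum of squares
// dl     : block scale; values: 16-entry codebook
static float cheapest_parity_fix(uint16_t v0, const float xv[4], const float diff[4],
                                 float sum_sq, float dl, const int8_t values[16]) {
    float best = std::numeric_limits<float>::max();
    for (int k = 0; k < 4; ++k) {
        const unsigned v  = (v0 >> 4*k) & 0xf;
        const int      pc = std::popcount(v);
        auto try_code = [&](unsigned cand) {
            if (((std::popcount(cand) ^ pc) & 1) == 0) return;   // group parity must flip
            const float d     = xv[k] - dl*values[cand];
            const float score = sum_sq - diff[k]*diff[k] + d*d;  // swap one error term
            if (score < best) best = score;
        };
        if (v > 0)  try_code(v - 1);
        if (v < 15) try_code(v + 1);
    }
    return best;  // squared error after the cheapest single-code nudge
}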
1 change: 1 addition & 0 deletions examples/quantize/quantize.cpp
@@ -44,6 +44,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
{ "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", },
{ "IQ4_KSS", LLAMA_FTYPE_MOSTLY_IQ4_KSS, " 4.0 bpw non-linear quantization", },
{ "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",},
{ "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",},
{ "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", },
2 changes: 2 additions & 0 deletions ggml/include/ggml.h
@@ -405,6 +405,7 @@ extern "C" {
GGML_TYPE_IQ1_TN = 143,
GGML_TYPE_IQ4_KS = 144,
GGML_TYPE_IQ2_KS = 145,
GGML_TYPE_IQ4_KSS = 146,
GGML_TYPE_COUNT,
};

@@ -462,6 +463,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ1_TN = 136, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors
};

// available tensor operations:
5 changes: 5 additions & 0 deletions ggml/src/ggml-common.h
@@ -447,6 +447,11 @@ typedef struct {
} block_iq4_ks;
static_assert(sizeof(block_iq4_ks) == QK_K/32 + QK_K/2, "wrong iq4_ks block size/padding");

typedef struct {
uint32_t qs[QK_K/8];
} block_iq4_kss;
static_assert(sizeof(block_iq4_kss) == QK_K/8*sizeof(uint32_t), "wrong iq4_kss block size/padding");

typedef struct {
ggml_half d;
uint16_t extra;
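The new block is nothing but QK_K/8 = 32 packed 32-bit words per 256 weights; the scales hide in bits stolen from those words, and the only side data is a single float per row. A quick footprint check, assuming the standard QK_K = 256:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    constexpr int QK_K = 256;                                        // super-block size
    constexpr std::size_t block_bytes = (QK_K/8)*sizeof(uint32_t);   // 128 bytes
    std::printf("bits per weight: %.2f\n", 8.0*block_bytes/QK_K);    // exactly 4.00
    // the per-row float scale adds 32/n_per_row bits, negligible for real row sizes
}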
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda.cu
@@ -2829,6 +2829,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -543,6 +543,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KS> {
static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KSS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> {
static constexpr int qk = QK_K;
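IQ4_KSS reuses the IQ4_XS traits because a block still spans QK_K weights and is walked in the same 32-bit strides by the MMVQ code. A sketch of how the constants relate, assuming the usual ggml-cuda definitions of these macros:

// Assumed from ggml-cuda/common.cuh: QR4_XS = 8 and QI4_XS = QK_K/(4*QR4_XS),
// i.e. the block size expressed in 32-bit words, scaled by the ratio qr.
constexpr int QK_K_   = 256;
constexpr int QR4_XS_ = 8;
constexpr int QI4_XS_ = QK_K_/(4*QR4_XS_);
static_assert(QI4_XS_ == 8, "traits arithmetic");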
43 changes: 43 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
@@ -638,6 +638,37 @@ static __global__ void dequantize_block_iq4_ks(const void * __restrict__ vx, dst
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq4_kss(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {

int64_t ii = blockIdx.x;
int64_t row = (QK_K * ii) / n_per_row;
const char * cx = (const char *)vx + row * row_size;
float scale = *(const float *)cx;
const block_iq4_kss * x = (const block_iq4_kss *)(cx + sizeof(float));
const int64_t i = ii - (row*n_per_row)/QK_K;

const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + ii*QK_K + 32*ib + 4*il;
const uint32_t * q4 = x[i].qs + 4*ib;
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
const float d = scale * ((ls & 254) - 127);
const int8_t * values = iq4k_values + ((ls & 1) << 4);
uint32_t aux32[2];
aux32[0] = q4[il] & 0xfffefffe;
aux32[0] ^= (aux32[0] >> 1);
aux32[1] = ((aux32[0] >> 4) & 0x0f0f0f0f);
aux32[0] &= 0x0f0f0f0f;
const uint8_t * aux8 = (const uint8_t *)aux32;
for (int j = 0; j < 4; ++j) {
y[j+ 0] = d * values[aux8[j+0]];
y[j+16] = d * values[aux8[j+4]];
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
@@ -980,6 +1011,14 @@ static void dequantize_row_iq4_ks_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq4_ks<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
}

template<typename dst_t>
static void dequantize_row_iq4_kss_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KSS, n_per_row);
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_kss<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
}

template<typename dst_t>
static void dequantize_row_iq2_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
@@ -1152,6 +1191,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_KS:
return dequantize_row_iq4_ks_cuda;
case GGML_TYPE_IQ4_KSS:
return dequantize_row_iq4_kss_cuda;
case GGML_TYPE_IQ2_KS:
return dequantize_row_iq2_ks_cuda;
case GGML_TYPE_IQ2_K:
@@ -1225,6 +1266,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_KS:
return dequantize_row_iq4_ks_cuda;
case GGML_TYPE_IQ4_KSS:
return dequantize_row_iq4_kss_cuda;
case GGML_TYPE_IQ2_KS:
return dequantize_row_iq2_ks_cuda;
case GGML_TYPE_IQ2_K:
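Two bit tricks in dequantize_block_iq4_kss are easy to miss: every 16-bit half of a packed word donates its lowest bit, and the eight collected bits reassemble the scale byte ls (bits 1-7 give the scale via (ls & 254) - 127, bit 0 selects which half of iq4k_values to use); the remaining 15 bits per half store the four nibbles as a running XOR, undone by w ^= w >> 1 once the stolen bits are masked away. A host-side sketch of the same decode, an assumption-laden reference rather than the upstream API:

#include <cstdint>

// Decodes one group of four packed words = 32 quants. Nibbles are emitted in
// plain little-endian order here; the kernel instead interleaves them into
// its y[j]/y[j+16] output layout.
static void decode_iq4_kss_group(const uint32_t q4[4], uint8_t nibbles[32], uint8_t * ls) {
    uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2)
                 | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
    *ls = (s32 | (s32 >> 15)) & 0xff;        // gather the 8 stolen scale bits
    for (int j = 0; j < 4; ++j) {
        uint32_t w = q4[j] & 0xfffefffe;     // drop the stolen bits
        w ^= (w >> 1);                       // undo the running-XOR packing
        for (int k = 0; k < 8; ++k) {
            nibbles[8*j + k] = (w >> 4*k) & 0xf;
        }
    }
}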
36 changes: 36 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -239,6 +239,35 @@ __device__ __forceinline__ float vec_dot_iq4_ks_q8_1(
return dl * __low2float(bq8_1[ib32].ds) * sumi;
}

#define VDR_IQ4_KSS_Q8_1_MMVQ 4
#define VDR_IQ4_KSS_Q8_1_MMQ 4

__device__ __forceinline__ float vec_dot_iq4_kss_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

float scale = *(const float *)vbq;
const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx;
const uint8_t * all_values = (const uint8_t *)iq4k_values;

// iqs is 0...28
const int ib32 = iqs/4; // Why iqs/4 ?
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
const float dl = scale * ((ls & 254) - 127);
int v1, v2;
int sumi = 0;
for (int j = 0; j < 4; ++j) {
uint32_t aux32 = q4[j] & 0xfffefffe;
aux32 ^= (aux32 >> 1);
get_int_from_table_16_shift(aux32, ls & 1, all_values, v1, v2);
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
}
return dl * __low2float(bq8_1[ib32].ds) * sumi;
}

#define VDR_IQ5_K_Q8_1_MMVQ 4
#define VDR_IQ5_K_Q8_1_MMQ 4

@@ -703,6 +732,13 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

void mul_mat_vec_iq4_kss_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

void mul_mat_vec_iq2_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
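The new vec_dot_iq4_kss_q8_1 above decodes a group on the fly and feeds it to ggml_cuda_dp4a, which multiplies four packed int8 lanes and accumulates the products into a 32-bit sum (a single DP4A instruction on supporting GPUs). A scalar reference of that step, handy for checking the kernel on paper:

#include <cstdint>

// Scalar stand-in for ggml_cuda_dp4a: treat a and b as 4 packed signed bytes,
// multiply lane-wise, and add the four products to acc.
static int dp4a_ref(int a, int b, int acc) {
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = int8_t(uint32_t(a) >> (8*i));
        const int8_t bi = int8_t(uint32_t(b) >> (8*i));
        acc += int(ai)*int(bi);
    }
    return acc;
}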
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -32,6 +32,10 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

void mul_mat_vec_iq4_kss_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

void mul_mat_vec_iq2_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/mmvq.cu
@@ -462,6 +462,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ4_KS:
mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ4_KSS:
mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_KS:
mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;