Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qi = QI3_S;
};

// Compile-time layout constants for GGML_TYPE_IQ4_KT.
// Note that qr and qi are taken from the IQ4_XS constants (QR4_XS / QI4_XS);
// no IQ4_KT-specific constants are used here.
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KT> {
    static constexpr int qk = QK_K;    // uses the shared QK_K constant
    static constexpr int qr = QR4_XS;  // borrowed from IQ4_XS
    static constexpr int qi = QI4_XS;  // borrowed from IQ4_XS
};

//////////////////////

struct ggml_cuda_device_info {
Expand Down
20 changes: 13 additions & 7 deletions ggml/src/ggml-cuda/convert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,13 @@ inline __device__ int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}

// Advances the generator state `val` in place by one multiply-add step
// (uint32_t arithmetic, so it wraps modulo 2^32) and returns an integer sample
// derived from the new state.
//
// The sample is ggml_cuda_dp4a(state & 0x3f3f3f3f, 0x01010101, -126): each byte
// of the state is first masked to its low 6 bits.
// NOTE(review): presumably dp4a sums the four masked bytes and adds the -126
// bias, giving a value in [-126, 126] -- confirm against ggml_cuda_dp4a.
int __device__ __forceinline__ trellis_next_int(uint32_t& val) {
    constexpr uint32_t kStepMul  = 89226354;
    constexpr uint32_t kStepAdd  = 64248484;
    constexpr uint32_t kByteMask = 0x3f3f3f3f;

    val = val*kStepMul + kStepAdd;

    const uint32_t masked = val & kByteMask;
    return ggml_cuda_dp4a(masked, 0x01010101, -126);
}

float __device__ __forceinline__ trellis_next(uint32_t& val) {
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
Expand Down Expand Up @@ -367,9 +374,9 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst
dst_t * y = yy + ii*QK_K + 8*ib;
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx = ql[ib] + 4096;
const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f;
const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f;
for (int j = 0; j < 8; ++j) {
y[j] = dl * trellis_next(idx);
y[j] = dl * trellis_next_int(idx);
}
}

Expand Down Expand Up @@ -401,9 +408,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int64_t ii = blockIdx.x;
int64_t row = (QK_K * ii) / n_per_row;
const float * dptr = (const float *)((const char *)vx + row * row_size);
float scale = dptr[0] * 31.75f * 1.01f;
float row_av = dptr[1];
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
float scale = dptr[0] * 1.00f;
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1);
const int64_t i = ii - (row*n_per_row)/QK_K;

constexpr int kNumGroups = 64;
Expand All @@ -423,8 +429,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int ls = ((shb[ib32] & 0xff) >> 1) - 64;
const float dl = scale * ls;
for (int j = 0; j < 4; ++j) {
y[j+0] = dl * trellis_next(idx1) + row_av;
y[j+4] = dl * trellis_next(idx2) + row_av;
y[j+0] = dl * trellis_next_int(idx1);
y[j+4] = dl * trellis_next_int(idx2);
}
}

Expand Down
89 changes: 89 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,79 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1(
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}

// Adds to *result the dot product of one IQ4_KT sub-block with the matching
// Q8_1 block.
//
//   vbq    - start of the quantized row: a single float scale, immediately
//            followed by the array of block_iq4_kt
//   bq8_1  - Q8_1 blocks; the block at index iqs/4 is the one consumed
//   kbx    - index of the block_iq4_kt within the row
//   iqs    - sub-block selector (0...28 per the original author's note)
//   result - accumulator, updated in place
//
// Data layout as read by this function: qs[0..7] hold one packed header word
// per sub-block (bit 0 -> index offset, bits 1..7 -> scale, bits 8.. -> eight
// 3-bit fields); the bytes starting at qs + 8 hold the low index bytes, and
// 64 bytes further on the 4-bit high index nibbles.
__device__ __forceinline__ void vec_dot_iq4_kt_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {

    constexpr uint32_t lcg_mul   = 89226354;
    constexpr uint32_t lcg_add   = 64248484;
    constexpr uint32_t byte_mask = 0x3f3f3f3f;

    const float row_scale = *(const float *)vbq;
    const block_iq4_kt * blk = (const block_iq4_kt *)((const char *)vbq + sizeof(float)) + kbx;

    // iqs is 0...28.  NOTE(review): the original author left an open question
    // about why the divisor is 4 -- still to be confirmed.
    const int ib32 = iqs/4;
    const int32_t * act = (const int *)bq8_1[ib32].qs;

    // Packed header word for this sub-block.
    const auto header = blk->qs[ib32];

    const int ls = (header & 0xff) >> 1;
    const float dl = row_scale * (ls - 64);
    const uint32_t base_idx = ((header & 1) << 15) + 4096;

    const uint8_t * packed = (const uint8_t *)(blk->qs + 8);
    const uint8_t * lo = packed + 8*ib32;
    const uint8_t * hi = packed + 64 + 8*(ib32%4);
    // Left shift that moves the selected nibble of hi[] into bits 8..11.
    const int hi_shift = 8 - 4*(ib32/4);

    int sumi = 0;
    for (int j = 0; j < 8; ++j) {
        const uint32_t sh = header >> (8 + 3*j);
        uint32_t state = lo[j] + ((hi[j] << hi_shift) & 0xf00) + ((sh & 7) << 12) + base_idx;

        // Generate four consecutive samples and pack them, one per byte,
        // into a single word so one dp4a can consume them together.
        int quad = 0;
        for (int bit = 0; bit < 32; bit += 8) {
            state = lcg_mul*state + lcg_add;
            quad |= (ggml_cuda_dp4a(state & byte_mask, 0x01010101, -126) & 0xff) << bit;
        }
        sumi = ggml_cuda_dp4a(quad, act[j], sumi);
    }

    *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}

// Adds to *result the dot product of one IQ2_KT sub-block with the matching
// Q8_1 block.
//
//   vbq    - start of the quantized row: a single float scale, immediately
//            followed by the array of block_iq2_kt
//   bq8_1  - Q8_1 blocks; the block at index iqs/4 is the one consumed
//   kbx    - index of the block_iq2_kt within the row
//   iqs    - sub-block selector (0...28 per the original author's note)
//   result - accumulator, updated in place
//
// Each of the four 16-bit entries read from ql seeds a generator that is
// stepped eight times; the eight samples are packed four per word and
// multiplied against two consecutive 32-bit words of the Q8_1 quants.
__device__ __forceinline__ void vec_dot_iq2_kt_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {

    constexpr uint32_t lcg_mul   = 89226354;
    constexpr uint32_t lcg_add   = 64248484;
    constexpr uint32_t byte_mask = 0x3f3f3f3f;

    const float row_scale = *(const float *)vbq;
    const block_iq2_kt * blk = (const block_iq2_kt *)((const char *)vbq + sizeof(float)) + kbx;

    // iqs is 0...28
    const int ib32 = iqs/4;
    const int32_t * act = (const int *)bq8_1[ib32].qs;

    // 4-bit scale index: byte ib32%4 of scales, nibble selected by ib32/4,
    // looked up in iq4k_values.
    const int ls = iq4k_values[(blk->scales[ib32%4] >> 4*(ib32/4)) & 0xf];
    const float dl = row_scale * ls * 1.05f;

    const uint16_t * seeds = (const uint16_t *)blk->ql + 4*ib32;

    int sumi = 0;
    for (int j = 0; j < 4; ++j) {
        uint32_t state = seeds[j] + 4096;
        // The state carries over from the first half to the second, so the
        // two halves together consume eight consecutive generator steps.
        for (int half = 0; half < 2; ++half) {
            int quad = 0;
            for (int bit = 0; bit < 32; bit += 8) {
                state = lcg_mul*state + lcg_add;
                quad |= (ggml_cuda_dp4a(state & byte_mask, 0x01010101, -126) & 0xff) << bit;
            }
            sumi = ggml_cuda_dp4a(quad, act[2*j+half], sumi);
        }
    }

    *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}

#define VDR_IQ4_KSS_Q8_1_MMVQ 4
#define VDR_IQ4_KSS_Q8_1_MMQ 4

Expand Down Expand Up @@ -1171,6 +1244,22 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}

// Host entry point for the IQ4_KT x Q8_1 matrix-vector product.
// Forwards every argument unchanged to iqk_mul_mat_vec_q_cuda, instantiated
// with vec_dot_iq4_kt_q8_1 as the per-block dot product.
// Note: the VDR template argument reuses the IQ4_KS constant
// (VDR_IQ4_KS_Q8_1_MMVQ); no IQ4_KT-specific one is used here.
void mul_mat_vec_iq4_kt_q8_1_cuda(
    const void * vx, const void * vy, float * dst, const char * ids_data,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {

    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_kt_q8_1>(
        vx, vy, dst, ids_data,
        ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
        ne2, nb02, nb12, nb2, ids_nb0,
        stream);
}

// Host entry point for the IQ2_KT x Q8_1 matrix-vector product.
// Forwards every argument unchanged to iqk_mul_mat_vec_q_cuda, instantiated
// with vec_dot_iq2_kt_q8_1 as the per-block dot product.
// Note: the VDR template argument reuses the IQ4_KS constant
// (VDR_IQ4_KS_Q8_1_MMVQ); no IQ2_KT-specific one is used here.
void mul_mat_vec_iq2_kt_q8_1_cuda(
    const void * vx, const void * vy, float * dst, const char * ids_data,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {

    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq2_kt_q8_1>(
        vx, vy, dst, ids_data,
        ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
        ne2, nb02, nb12, nb2, ids_nb0,
        stream);
}

void mul_mat_vec_iq4_kss_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
Expand Down
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,13 @@ void mul_mat_vec_iq1_s_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);

// Launcher for the IQ4_KT x Q8_1 matrix-vector product on the given stream.
// The parameter list mirrors the other mul_mat_vec_*_q8_1_cuda declarations
// in this header.
// NOTE(review): vx is presumably the quantized IQ4_KT data and vy the Q8_1
// quantized activations -- confirm against the definition in iqk_mmvq.cu.
void mul_mat_vec_iq4_kt_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);

// Launcher for the IQ2_KT x Q8_1 matrix-vector product on the given stream.
// The parameter list mirrors the other mul_mat_vec_*_q8_1_cuda declarations
// in this header.
// NOTE(review): vx is presumably the quantized IQ2_KT data and vy the Q8_1
// quantized activations -- confirm against the definition in iqk_mmvq.cu.
void mul_mat_vec_iq2_kt_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ4_KS_R4:
mul_mat_q_case<GGML_TYPE_IQ4_KS_R4>(ctx, args, stream);
break;
case GGML_TYPE_IQ4_KT:
mul_mat_q_case<GGML_TYPE_IQ4_KT>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_KT:
mul_mat_q_case<GGML_TYPE_IQ2_KT>(ctx, args, stream);
break;
case GGML_TYPE_IQ5_KS:
mul_mat_q_case<GGML_TYPE_IQ5_KS>(ctx, args, stream);
break;
Expand Down Expand Up @@ -172,6 +178,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ2_KT:
mmq_supported = true;
break;
default:
Expand Down
Loading