4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda.cu
@@ -3470,6 +3470,10 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K_R4:
return true;
default:
return false;
246 changes: 246 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
@@ -754,6 +754,53 @@ static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq4_k_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
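// The _r4 layout interleaves 4 rows per block: blockIdx.x walks 256-value
// super-blocks in plain row-major order, so recover the logical row, its
// position ir within its group of 4 interleaved rows, and the index ibl of the
// block_iq4_k_r4 that stores it. The other *_r4 kernels below use the same
// indexing.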

int64_t ii = blockIdx.x;

int64_t nblock = n_per_row/256;
int64_t row = ii/nblock;
int64_t row4 = row/4;
int64_t ir = row%4;
int64_t ibl = row4*nblock + ii%nblock;

const int tid = threadIdx.x;
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7

const block_iq4_k_r4 * x = (const block_iq4_k_r4 *)vx;
dst_t * y = yy + 256*ii + 32*ib;

const float d = __half2float(x[ibl].d[ir]);
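// Per 16-value sub-block scales are 6-bit: low 4 bits in scales_l, high 2 bits
// in scales_h, stored with an offset of 32. The matching bits of extra[ir] and
// extra[ir+4] pick which 16-entry half of the iq4k_values table holds the
// non-linear quant values for the first and second half of the 32-value chunk.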
int is = 8*ib + ir;
float dl1 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
is += 4;
float dl2 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
auto values1 = iq4k_values + (((x[ibl].extra[ir+0] >> ib) & 1) << 4);
auto values2 = iq4k_values + (((x[ibl].extra[ir+4] >> ib) & 1) << 4);
auto qs = x[ibl].qs + 64*ib + 4*ir;
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
y[il+ 0] = __float2bfloat16(dl1 * values1[qs[il+ 0] & 0xf]);
y[il+ 8] = __float2bfloat16(dl1 * values1[qs[il+ 0] >> 4]);
y[il+16] = __float2bfloat16(dl2 * values2[qs[il+16] & 0xf]);
y[il+24] = __float2bfloat16(dl2 * values2[qs[il+16] >> 4]);
y[il+ 4] = __float2bfloat16(dl1 * values1[qs[il+32] & 0xf]);
y[il+12] = __float2bfloat16(dl1 * values1[qs[il+32] >> 4]);
y[il+20] = __float2bfloat16(dl2 * values2[qs[il+48] & 0xf]);
y[il+28] = __float2bfloat16(dl2 * values2[qs[il+48] >> 4]);
} else {
y[il+ 0] = dl1 * values1[qs[il+ 0] & 0xf];
y[il+ 4] = dl1 * values1[qs[il+32] & 0xf];
y[il+ 8] = dl1 * values1[qs[il+ 0] >> 4];
y[il+12] = dl1 * values1[qs[il+32] >> 4];
y[il+16] = dl2 * values2[qs[il+16] & 0xf];
y[il+20] = dl2 * values2[qs[il+48] & 0xf];
y[il+24] = dl2 * values2[qs[il+16] >> 4];
y[il+28] = dl2 * values2[qs[il+48] >> 4];
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {

@@ -791,6 +838,149 @@ static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq5_k_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
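// Same interleaved indexing and 6-bit scale scheme as dequantize_block_iq4_k_r4
// above; the quants are 5-bit here, so qh supplies the high bit of each index
// and extra selects one of the two 32-entry halves of iq5nl_values.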

int64_t ii = blockIdx.x;

int64_t nblock = n_per_row/256;
int64_t row = ii/nblock;
int64_t row4 = row/4;
int64_t ir = row%4;
int64_t ibl = row4*nblock + ii%nblock;

const int tid = threadIdx.x;
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7

const block_iq5_k_r4 * x = (const block_iq5_k_r4 *)vx;
dst_t * y = yy + 256*ii + 32*ib;

const float d = __half2float(x[ibl].d[ir]);
int is = 8*ib + ir;
float dl1 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
is += 4;
float dl2 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
auto values1 = iq5nl_values + (((x[ibl].extra[ir+0] >> ib) & 1) << 5);
auto values2 = iq5nl_values + (((x[ibl].extra[ir+4] >> ib) & 1) << 5);
auto qs = x[ibl].qs + 64*ib + 4*ir;
auto qh = x[ibl].qh + 16*ib + 4*ir;
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
y[il+ 0] = __float2bfloat16(dl1 * values1[(qs[il+ 0] & 0xf) | (((qh[il] >> 0) & 1) << 4)]);
y[il+ 4] = __float2bfloat16(dl1 * values1[(qs[il+32] & 0xf) | (((qh[il] >> 4) & 1) << 4)]);
y[il+ 8] = __float2bfloat16(dl1 * values1[(qs[il+ 0] >> 4) | (((qh[il] >> 1) & 1) << 4)]);
y[il+12] = __float2bfloat16(dl1 * values1[(qs[il+32] >> 4) | (((qh[il] >> 5) & 1) << 4)]);
y[il+16] = __float2bfloat16(dl2 * values2[(qs[il+16] & 0xf) | (((qh[il] >> 2) & 1) << 4)]);
y[il+20] = __float2bfloat16(dl2 * values2[(qs[il+48] & 0xf) | (((qh[il] >> 6) & 1) << 4)]);
y[il+24] = __float2bfloat16(dl2 * values2[(qs[il+16] >> 4) | (((qh[il] >> 3) & 1) << 4)]);
y[il+28] = __float2bfloat16(dl2 * values2[(qs[il+48] >> 4) | (((qh[il] >> 7) & 1) << 4)]);
} else {
y[il+ 0] = dl1 * values1[(qs[il+ 0] & 0xf) | (((qh[il] >> 0) & 1) << 4)];
y[il+ 4] = dl1 * values1[(qs[il+32] & 0xf) | (((qh[il] >> 4) & 1) << 4)];
y[il+ 8] = dl1 * values1[(qs[il+ 0] >> 4) | (((qh[il] >> 1) & 1) << 4)];
y[il+12] = dl1 * values1[(qs[il+32] >> 4) | (((qh[il] >> 5) & 1) << 4)];
y[il+16] = dl2 * values2[(qs[il+16] & 0xf) | (((qh[il] >> 2) & 1) << 4)];
y[il+20] = dl2 * values2[(qs[il+48] & 0xf) | (((qh[il] >> 6) & 1) << 4)];
y[il+24] = dl2 * values2[(qs[il+16] >> 4) | (((qh[il] >> 3) & 1) << 4)];
y[il+28] = dl2 * values2[(qs[il+48] >> 4) | (((qh[il] >> 7) & 1) << 4)];
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq2_k_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
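// 2-bit variant of the interleaved kernels above: four quants per byte of qs,
// unpacked with 2-bit shifts. Each 16-value group has a single 4-bit scale
// stored with an offset of 8, and extra selects one of the two 4-entry halves
// of iq2nl_values.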

int64_t ii = blockIdx.x;

int64_t nblock = n_per_row/256;
int64_t row = ii/nblock;
int64_t row4 = row/4;
int64_t ir = row%4;
int64_t ibl = row4*nblock + ii%nblock;

const int tid = threadIdx.x;
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7

const block_iq2_k_r4 * x = (const block_iq2_k_r4 *)vx;
dst_t * y = yy + 256*ii + 32*ib;

const float d = __half2float(x[ibl].d[ir]);
int is = 8*ib + ir;
float dl1 = d * (((x[ibl].scales[is%32] >> 4*(is/32)) & 0xf) - 8);
is += 4;
float dl2 = d * (((x[ibl].scales[is%32] >> 4*(is/32)) & 0xf) - 8);
auto values1 = iq2nl_values + (((x[ibl].extra[ir+0] >> ib) & 1) << 2);
auto values2 = iq2nl_values + (((x[ibl].extra[ir+4] >> ib) & 1) << 2);
auto ql = x[ibl].qs + 32*ib + 4*ir;
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
y[il+ 0] = __float2bfloat16(dl1 * values1[(ql[il+ 0] >> 0) & 3]);
y[il+ 4] = __float2bfloat16(dl1 * values1[(ql[il+ 0] >> 2) & 3]);
y[il+ 8] = __float2bfloat16(dl1 * values1[(ql[il+ 0] >> 4) & 3]);
y[il+12] = __float2bfloat16(dl1 * values1[(ql[il+ 0] >> 6) & 3]);
y[il+16] = __float2bfloat16(dl2 * values2[(ql[il+16] >> 0) & 3]);
y[il+20] = __float2bfloat16(dl2 * values2[(ql[il+16] >> 2) & 3]);
y[il+24] = __float2bfloat16(dl2 * values2[(ql[il+16] >> 4) & 3]);
y[il+28] = __float2bfloat16(dl2 * values2[(ql[il+16] >> 6) & 3]);
} else {
y[il+ 0] = dl1 * values1[(ql[il+ 0] >> 0) & 3];
y[il+ 4] = dl1 * values1[(ql[il+ 0] >> 2) & 3];
y[il+ 8] = dl1 * values1[(ql[il+ 0] >> 4) & 3];
y[il+12] = dl1 * values1[(ql[il+ 0] >> 6) & 3];
y[il+16] = dl2 * values2[(ql[il+16] >> 0) & 3];
y[il+20] = dl2 * values2[(ql[il+16] >> 2) & 3];
y[il+24] = dl2 * values2[(ql[il+16] >> 4) & 3];
y[il+28] = dl2 * values2[(ql[il+16] >> 6) & 3];
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq3_k_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
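// 3-bit variant: the two low bits of each index come from qs, the third bit
// from qh. The scale magnitude is the odd value 2*scales_l + 1 with its sign
// taken from the matching bit of scales_h; extra selects one of the two
// 8-entry halves of iq3nl_values.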

int64_t ii = blockIdx.x;

int64_t nblock = n_per_row/256;
int64_t row = ii/nblock;
int64_t row4 = row/4;
int64_t ir = row%4;
int64_t ibl = row4*nblock + ii%nblock;

const int tid = threadIdx.x;
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7

const block_iq3_k_r4 * x = (const block_iq3_k_r4 *)vx;
dst_t * y = yy + 256*ii + 32*ib;

const float d = __half2float(x[ibl].d[ir]);
int is = 8*ib + ir;
float dl1 = d * (2*((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) + 1) * ((x[ibl].scales_h[is%8] >> (is/8)) & 1 ? -1 : 1);
is += 4;
float dl2 = d * (2*((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) + 1) * ((x[ibl].scales_h[is%8] >> (is/8)) & 1 ? -1 : 1);
auto values1 = iq3nl_values + (((x[ibl].extra[ir+0] >> ib) & 1) << 3);
auto values2 = iq3nl_values + (((x[ibl].extra[ir+4] >> ib) & 1) << 3);
auto ql = x[ibl].qs + 32*ib + 4*ir;
auto qh = x[ibl].qh + 16*ib + 4*ir;
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
y[il+ 0] = __float2bfloat16(dl1 * values1[((ql[il+ 0] >> 0) & 3) | ((qh[il] << 2) & 4)]);
y[il+ 4] = __float2bfloat16(dl1 * values1[((ql[il+ 0] >> 2) & 3) | ((qh[il] << 1) & 4)]);
y[il+ 8] = __float2bfloat16(dl1 * values1[((ql[il+ 0] >> 4) & 3) | ((qh[il] << 0) & 4)]);
y[il+12] = __float2bfloat16(dl1 * values1[((ql[il+ 0] >> 6) & 3) | ((qh[il] >> 1) & 4)]);
y[il+16] = __float2bfloat16(dl2 * values2[((ql[il+16] >> 0) & 3) | ((qh[il] >> 2) & 4)]);
y[il+20] = __float2bfloat16(dl2 * values2[((ql[il+16] >> 2) & 3) | ((qh[il] >> 3) & 4)]);
y[il+24] = __float2bfloat16(dl2 * values2[((ql[il+16] >> 4) & 3) | ((qh[il] >> 4) & 4)]);
y[il+28] = __float2bfloat16(dl2 * values2[((ql[il+16] >> 6) & 3) | ((qh[il] >> 5) & 4)]);
} else {
y[il+ 0] = dl1 * values1[((ql[il+ 0] >> 0) & 3) | ((qh[il] << 2) & 4)];
y[il+ 4] = dl1 * values1[((ql[il+ 0] >> 2) & 3) | ((qh[il] << 1) & 4)];
y[il+ 8] = dl1 * values1[((ql[il+ 0] >> 4) & 3) | ((qh[il] << 0) & 4)];
y[il+12] = dl1 * values1[((ql[il+ 0] >> 6) & 3) | ((qh[il] >> 1) & 4)];
y[il+16] = dl2 * values2[((ql[il+16] >> 0) & 3) | ((qh[il] >> 2) & 4)];
y[il+20] = dl2 * values2[((ql[il+16] >> 2) & 3) | ((qh[il] >> 3) & 4)];
y[il+24] = dl2 * values2[((ql[il+16] >> 4) & 3) | ((qh[il] >> 4) & 4)];
y[il+28] = dl2 * values2[((ql[il+16] >> 6) & 3) | ((qh[il] >> 5) & 4)];
}
}


template<typename dst_t>
static __global__ void dequantize_block_iq5_ks(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
@@ -1202,20 +1392,52 @@ static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq3_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq3_k_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_K, n_per_row);
const int nb = (k + QK_K - 1) / QK_K;
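// one 32-thread block per 256-value (QK_K) super-block; the kernel derives the
// interleaved _r4 indexing from the flat block index, and row_size is not
// referenced by the kernel bodies in this change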
dequantize_block_iq3_k_r4<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
}

template<typename dst_t>
static void dequantize_row_iq2_k_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_K, n_per_row);
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq2_k_r4<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
}

template<typename dst_t>
static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_k_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_K, n_per_row);
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_k_r4<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
}

template<typename dst_t>
static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq5_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq5_k_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_K, n_per_row);
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq5_k_r4<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
}

template<typename dst_t>
static void dequantize_row_iq6_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
@@ -1312,6 +1534,14 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
return dequantize_row_iq5_k_cuda<nv_bfloat16>;
case GGML_TYPE_IQ6_K:
return dequantize_row_iq6_k_cuda<nv_bfloat16>;
case GGML_TYPE_IQ2_K_R4:
return dequantize_row_iq2_k_r4_cuda<nv_bfloat16>;
case GGML_TYPE_IQ3_K_R4:
return dequantize_row_iq3_k_r4_cuda<nv_bfloat16>;
case GGML_TYPE_IQ4_K_R4:
return dequantize_row_iq4_k_r4_cuda<nv_bfloat16>;
case GGML_TYPE_IQ5_K_R4:
return dequantize_row_iq5_k_r4_cuda<nv_bfloat16>;
default:
return nullptr;
}
@@ -1394,6 +1624,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return convert_unary_cuda<float>;
case GGML_TYPE_BF16:
return convert_from_bf16_cuda;
case GGML_TYPE_IQ2_K_R4:
return dequantize_row_iq2_k_r4_cuda;
case GGML_TYPE_IQ3_K_R4:
return dequantize_row_iq3_k_r4_cuda;
case GGML_TYPE_IQ4_K_R4:
return dequantize_row_iq4_k_r4_cuda;
case GGML_TYPE_IQ5_K_R4:
return dequantize_row_iq5_k_r4_cuda;
default:
return nullptr;
}
@@ -1473,6 +1711,14 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return convert_unary_cuda<half>;
case GGML_TYPE_BF16:
return convert_from_bf16_cuda;
case GGML_TYPE_IQ2_K_R4:
return dequantize_row_iq2_k_r4_cuda;
case GGML_TYPE_IQ3_K_R4:
return dequantize_row_iq3_k_r4_cuda;
case GGML_TYPE_IQ4_K_R4:
return dequantize_row_iq4_k_r4_cuda;
case GGML_TYPE_IQ5_K_R4:
return dequantize_row_iq5_k_r4_cuda;
default:
return nullptr;
}
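
For context, a minimal sketch of how a converter returned by these dispatch tables is typically invoked. The wrapper name dequantize_to_f16 and the variables src, dst_f16 and stream are illustrative assumptions rather than part of this diff; dst_f16 is assumed to be a device buffer holding ggml_nelements(src) half values.

static bool dequantize_to_f16(const ggml_tensor * src, half * dst_f16, cudaStream_t stream) {
    // Hedged sketch: look up the CUDA dequantizer for this tensor type and run
    // it over the whole (contiguous) tensor. Returns false when the type has
    // no converter, i.e. ggml_get_to_fp16_cuda returned nullptr.
    const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src->type);
    if (to_fp16 == nullptr) {
        return false;
    }
    const int64_t n_per_row = src->ne[0];      // elements per row
    const int64_t nrows     = ggml_nrows(src); // total number of rows
    to_fp16(src->data, dst_f16, nrows, n_per_row, stream);
    return true;
}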