15 changes: 8 additions & 7 deletions ggml/src/ggml-cuda.cu
@@ -1235,7 +1235,7 @@ static void ggml_cuda_op_mul_mat_cublas(
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = row_diff*ne00;
src0_as_f16.alloc(ne);
to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
to_fp16_cuda(src0_dd_i, src0_as_f16.get(), row_diff, ne00, stream);
}
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();

@@ -1245,7 +1245,7 @@ static void ggml_cuda_op_mul_mat_cublas(
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = src1_ncols*ne10;
src1_as_f16.alloc(ne);
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), src1_ncols, ne10, stream);
}
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
@@ -1264,7 +1264,7 @@ static void ggml_cuda_op_mul_mat_cublas(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff, src1_ncols, stream);
} else {
ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
@@ -1273,13 +1273,13 @@ static void ggml_cuda_op_mul_mat_cublas(
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
GGML_ASSERT(to_fp32_cuda != nullptr);
src0_ddq_as_f32.alloc(row_diff*ne00);
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff, ne00, stream);
}
if (src1->type != GGML_TYPE_F32) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
GGML_ASSERT(to_fp32_cuda != nullptr);
src1_ddq_as_f32.alloc(src1_ncols*ne10);
to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols, ne10, stream);
}

const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
@@ -1779,7 +1779,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
const int64_t ne_src1 = ggml_nelements(src1);
src1_f16_alloc.alloc(ne_src1);
GGML_ASSERT(to_fp16_cuda != nullptr);
to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ggml_nrows(src1), src1->ne[0], main_stream);
}
half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();

@@ -1894,7 +1894,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co

if (dst->op_params[0] == GGML_PREC_DEFAULT) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
to_fp32_cuda(dst_f16.get(), dst_ddf, ggml_nrows(dst), dst->ne[0], main_stream);
}
}

@@ -2790,6 +2790,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ1_TN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_TN:
return true;
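For context on the call-site updates above: the conversion helpers now take the row count and the row length instead of a single flat element count. For plain F32/F16 data the product is the same total as before, but row-wise quantized types need the split to locate per-row metadata. Below is a minimal, self-contained sketch of a converter in the new shape; the names f32_to_f16 / f32_to_f16_cuda and the toy main are hypothetical, and in the PR the real dispatch goes through ggml_get_to_fp16_cuda / convert_unary_cuda.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Element-wise F32 -> F16 over k = nrows * n_per_row values.
static __global__ void f32_to_f16(const float * x, half * y, const int64_t k) {
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
    if (i < k) {
        y[i] = __float2half(x[i]);
    }
}

// Wrapper matching the updated to_fp16_cuda_t shape: (src, dst, nrows, n_per_row, stream).
static void f32_to_f16_cuda(const void * vx, half * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
    const int64_t k = nrows*n_per_row;            // same total as the old flat count
    const int num_blocks = (int)((k + 255)/256);
    f32_to_f16<<<num_blocks, 256, 0, stream>>>((const float *)vx, y, k);
}

int main() {
    const int64_t nrows = 4, n_per_row = 8;
    float * x; half * y;
    cudaMallocManaged(&x, nrows*n_per_row*sizeof(float));
    cudaMallocManaged(&y, nrows*n_per_row*sizeof(half));
    for (int64_t i = 0; i < nrows*n_per_row; ++i) x[i] = 0.25f*i;
    f32_to_f16_cuda(x, y, nrows, n_per_row, 0);   // default stream
    cudaDeviceSynchronize();
    printf("y[5] = %g\n", __half2float(y[5]));
    cudaFree(x); cudaFree(y);
    return 0;
}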
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -466,6 +466,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_BN> {
static constexpr int qi = QI1_BN;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_TN> {
static constexpr int qk = QK_IQ1BN;
static constexpr int qr = QR1_BN;
static constexpr int qi = QI1_BN;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
static constexpr int qk = QK_IQ1BN;
133 changes: 106 additions & 27 deletions ggml/src/ggml-cuda/convert.cu
@@ -446,6 +446,46 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq1_tn(const void * __restrict__ vx, dst_t * __restrict__ yy,
int64_t n_per_row, int64_t row_size) {

int64_t ii = blockIdx.x;
int64_t row = (QK_K * ii) / n_per_row;
const char * cx = (const char *)vx + row * row_size;
float scale = *(const half *)cx;
const block_iq1_bn * x = (const block_iq1_bn *)(cx + sizeof(half));

static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};

//#define COMPUTE_VS(v) 3*v >> 8
#define COMPUTE_VS(v) (v + (v >> 1)) >> 7

const int tid = threadIdx.x;
const int il = tid/4; // 0...7
const int ib = tid%4; // 0...3
dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
const int i16 = il/2;
int64_t i = QK_K/QK_IQ1BN * (ii - (row*n_per_row)/QK_K) + ib;
uint8_t q = x[i].ql[3*i16+2*(il%2)];
for (int j = 0; j < 5; ++j) {
uint8_t v = k_mult[j]*q;
int8_t vs = COMPUTE_VS(v);
y[2*(il%2)+j] = scale*(vs - 1);
}
q = x[i].ql[3*i16+1];
for (int j = 0; j < 2; ++j) {
uint8_t v = k_mult[3*(il%2)+j]*q;
int8_t vs = COMPUTE_VS(v);
y[5*(1-(il%2))+j] = scale*(vs-1);
}
uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
int8_t vs = COMPUTE_VS(v);
y[7] = scale*(vs - 1);

#undef COMPUTE_VS
}

template<typename dst_t>
static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) {

@@ -675,12 +715,14 @@ static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}

static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
const bool need_check = false;
@@ -692,149 +734,181 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half *
}

template<typename dst_t>
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template<typename dst_t>
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template<typename dst_t>
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb64 = k / QK_IQ1BN;
const int nb = (k + 255) / 256;
dequantize_block_iq1_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
}

template<typename dst_t>
static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq1_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row);
const int nb = (k + 255) / 256;
dequantize_block_iq1_tn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
}

template<typename dst_t>
static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb64 = k / QK_IQ1BN;
const int nb = (k + 255) / 256;
dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
}

template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq2_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq3_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq5_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq6_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
static void dequantize_row_iq6_k_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq6_k<<<nb, 32, 0, stream>>>(vx, y);
}
@@ -853,7 +927,8 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
}

template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
@@ -899,6 +974,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ1_BN:
return dequantize_row_iq1_bn_cuda;
case GGML_TYPE_IQ1_TN:
return dequantize_row_iq1_tn_cuda;
case GGML_TYPE_IQ2_BN:
return dequantize_row_iq2_bn_cuda;
case GGML_TYPE_IQ4_NL:
@@ -962,6 +1039,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ1_BN:
return dequantize_row_iq1_bn_cuda;
case GGML_TYPE_IQ1_TN:
return dequantize_row_iq1_tn_cuda;
case GGML_TYPE_IQ2_BN:
return dequantize_row_iq2_bn_cuda;
case GGML_TYPE_IQ4_NL:
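A note on the arithmetic in the new dequantize_block_iq1_tn kernel above: each ql byte packs five ternary digits, and COMPUTE_VS, i.e. (v + (v >> 1)) >> 7, is exactly floor(3*v/256), so multiplying the byte by 81/27/9/3/1 (with deliberate wraparound mod 256) brings one digit at a time into range. The small host-side check below round-trips all 243 combinations under the assumption that a byte packs the base-3 value L as ceil(256*L/243); the actual quantizer may build the byte differently, so treat the encoder here as illustrative.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Same decode as the CUDA kernel: digit j of the packed byte, mapped to {-1, 0, +1}.
static int decode_trit(uint8_t q, int j) {
    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
    const uint8_t v  = k_mult[j]*q;            // wraps mod 256 on purpose
    const int     vs = (v + (v >> 1)) >> 7;    // == floor(3*v/256), a value in {0, 1, 2}
    return vs - 1;
}

int main() {
    static const int pow3[5] = {1, 3, 9, 27, 81};
    for (int L = 0; L < 243; ++L) {                       // all 3^5 digit combinations
        const uint8_t q = (uint8_t)((256*L + 242)/243);   // assumed packing: ceil(256*L/243)
        for (int j = 0; j < 5; ++j) {
            const int expected = (L/pow3[j]) % 3 - 1;     // digit j of L, shifted to {-1, 0, +1}
            assert(decode_trit(q, j) == expected);
        }
    }
    printf("all 243 packed values decode to the expected ternary digits\n");
    return 0;
}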
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/convert.cuh
@@ -3,7 +3,7 @@
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256

template<typename T>
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t nrows, int64_t n_per_row, cudaStream_t stream);

typedef to_t_cuda_t<float> to_fp32_cuda_t;
typedef to_t_cuda_t<half> to_fp16_cuda_t;
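The extra (nrows, n_per_row) parameters exist because row-wise quantized types such as IQ1_TN store a per-row fp16 scale ahead of the packed data, so a dequantizer needs the row length and the row stride in bytes, not just a flat element count. A small self-contained sketch of the row addressing used by dequantize_block_iq1_tn follows; the 13-byte packed block and the 4096-element row are assumptions for the demo, and the PR obtains the real stride from ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row).

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t QK_K          = 256;      // elements per super-block, as in ggml
    const int64_t QK_IQ1BN      = 64;       // elements per packed IQ1 block
    const int64_t n_per_row     = 4096;     // illustrative row length
    const int64_t bytes_per_blk = 13;       // assumed: 12 ql bytes + 1 extra byte
    const int64_t row_size      = 2 + (n_per_row/QK_IQ1BN)*bytes_per_blk; // fp16 scale + packed blocks

    // Each CUDA block handles one QK_K super-block ii; the kernel maps it back
    // to its row so it can read that row's scale from the start of the row.
    for (int64_t ii = 15; ii <= 17; ++ii) {
        const int64_t row = (QK_K*ii)/n_per_row;
        const int64_t blk = (QK_K/QK_IQ1BN)*(ii - (row*n_per_row)/QK_K); // first packed block of this super-block within its row
        printf("super-block %3lld -> row %lld, scale at byte %lld, packed block %lld\n",
               (long long)ii, (long long)row, (long long)(row*row_size), (long long)blk);
    }
    return 0;
}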