1 change: 1 addition & 0 deletions ggml/src/ggml-cuda.cu
@@ -3476,6 +3476,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ5_K_R4:
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ1_S_R4:
return true;
default:
return false;
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -515,6 +515,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
static constexpr int qi = QI1_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S_R4> {
static constexpr int qk = 32;
static constexpr int qr = 2;
static constexpr int qi = 4;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
static constexpr int qk = QK_K;
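Note on the layout the new kernels index into: the data is interleaved in groups of 4 rows, with the four per-row fp16 scales stored once at the start of each 4-row group, followed by one block per 32 weights. The field names below are taken from the diff (qs, qh, dptr); the array sizes are a sketch inferred from the indexing (qs[4*il+ir], qh[ir], dptr[ir]) and are not authoritative — the real definition lives in the shared ggml headers.

// Hedged sketch of block_iq1_s_r4: 24 bytes covering 32 weights from each of 4 interleaved rows.
struct block_iq1_s_r4_sketch {
    uint8_t  qs[16];  // qs[4*il + ir]: low 8 bits of codebook index il (0..3) for interleaved row ir (0..3)
    uint16_t qh[4];   // per row ir: bits 0..11 = high 3 bits of the 4 indices, bits 12..14 = block scale, bit 15 = delta sign
};
// A 4-row group is stored as: half d[4] (one scale per interleaved row), then n_per_row/32 such blocks.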
49 changes: 49 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
@@ -526,6 +526,41 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq1_s_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {

int64_t ii = blockIdx.x;

int64_t nblock = n_per_row/32;
int64_t row = (8*ii)/nblock;
int64_t row4 = row/4;
int64_t ir = row%4;
int64_t ibl = (8*ii)%nblock;

const int tid = threadIdx.x;
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7

const half * dptr = (const half *)((const char *)vx + 4*row4*row_size);
const float d = (float)dptr[ir];
const block_iq1_s_r4 * x = (const block_iq1_s_r4 *)(dptr + 4) + ibl;
dst_t * y = yy + 256*ii + 32*ib + 8*il;

float dl = d*(2*((x[ib].qh[ir] >> 12) & 7) + 1);
float delta = dl * (x[ib].qh[ir] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA);

uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid_gpu[x[ib].qs[4*il+ir] | (((x[ib].qh[ir] >> 3*il) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;

if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
for (int j = 0; j < 8; ++j) y[j] = __float2bfloat16(dl*q[j] + delta);
} else {
for (int j = 0; j < 8; ++j) y[j] = dl*q[j] + delta;
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {

@@ -1398,6 +1433,14 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq1_s_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_S_R4, n_per_row);
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq1_s_r4<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
}

template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
@@ -1651,6 +1694,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
return dequantize_row_iq5_k_r4_cuda<nv_bfloat16>;
case GGML_TYPE_IQ5_KS_R4:
return dequantize_row_iq5_ks_r4_cuda<nv_bfloat16>;
case GGML_TYPE_IQ1_S_R4:
return dequantize_row_iq1_s_r4_cuda<nv_bfloat16>;
default:
return nullptr;
}
@@ -1699,6 +1744,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ1_S_R4:
return dequantize_row_iq1_s_r4_cuda;
case GGML_TYPE_IQ1_M:
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ1_BN:
@@ -1790,6 +1837,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ1_S_R4:
return dequantize_row_iq1_s_r4_cuda;
case GGML_TYPE_IQ1_M:
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ1_BN:
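Note on the dequantization added above: each CUDA block covers 256 consecutive output values (eight 32-weight sub-blocks, one 8-value codebook group per thread), which is why the launcher rounds the element count up to QK_K. The per-thread decode is equivalent to the scalar sketch below; it only restates the kernel body and assumes iq1s_grid_gpu and IQ1S_DELTA from the ggml CUDA headers.

// Decode the 8 weights of codebook index 'il' (0..3) of block 'ib' for interleaved row 'ir' (0..3);
// d is that row's fp16 scale converted to float.
static void decode8_iq1_s_r4(const block_iq1_s_r4 * x, float d, int ib, int il, int ir, float * y) {
    const float dl    = d * (2*((x[ib].qh[ir] >> 12) & 7) + 1);                          // row scale * (2*block_scale + 1)
    const float delta = dl * ((x[ib].qh[ir] & 0x8000) ? -1.f - IQ1S_DELTA : -1.f + IQ1S_DELTA);
    uint32_t grid32[2];
    grid32[0] = iq1s_grid_gpu[x[ib].qs[4*il+ir] | (((x[ib].qh[ir] >> 3*il) & 7) << 8)];  // 11-bit codebook index
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;                                           // high nibbles -> weights 4..7
    grid32[0] &= 0x0f0f0f0f;                                                             // low nibbles  -> weights 0..3
    const int8_t * q = (const int8_t *)grid32;
    for (int j = 0; j < 8; ++j) y[j] = dl*q[j] + delta;                                  // dl*(g - 1 +/- IQ1S_DELTA), g in {0,1,2}
}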
40 changes: 40 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -353,6 +353,38 @@ __device__ __forceinline__ void vec_dot_iq4_ks_r4_q8_1(
}
}

// TODO
__device__ __forceinline__ void vec_dot_iq1_s_r4_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {

const half * dptr = (const half *)vbq;
const block_iq1_s_r4 * bq1 = (const block_iq1_s_r4 *)(dptr + 4) + kbx;

// iqs is 0 or 2
const float d8 = __low2float(bq8_1->ds);
const int32_t * q8 = (const int *)bq8_1->qs;

int32_t grid32[2];
const int * igrid = (const int *)grid32;

int minus = 0;
for (int k = 0; k < 4; ++k) minus = ggml_cuda_dp4a(0x01010101, q8[4*(iqs/2)+k], minus);

for (int i = 0; i < 4; ++i) {
float dl = (float)dptr[i]*(2*((bq1->qh[i] >> 12) & 7) + 1) * d8;
float ml = dl * (bq1->qh[i] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA);
grid32[0] = iq1s_grid_gpu[bq1->qs[4*iqs+i] | (((bq1->qh[i] >> 3*iqs) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
int sumi = ggml_cuda_dp4a(igrid[0], q8[4*(iqs/2)+0], ggml_cuda_dp4a(igrid[1], q8[4*(iqs/2)+1], 0));
grid32[0] = iq1s_grid_gpu[bq1->qs[4*iqs+i+4] | (((bq1->qh[i] >> (3*iqs+3)) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
sumi = ggml_cuda_dp4a(igrid[0], q8[4*(iqs/2)+2], ggml_cuda_dp4a(igrid[1], q8[4*(iqs/2)+3], sumi));
result[i] += dl * sumi + ml * minus;
}
}

#define VDR_IQ4_KS_Q8_1_MMVQ 4
#define VDR_IQ4_KS_Q8_1_MMQ 4

@@ -1106,6 +1138,14 @@ void mul_mat_vec_iq4_ks_r4_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS_R4, 2, vec_dot_iq4_ks_r4_q8_1, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}

void mul_mat_vec_iq1_s_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {

iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S_R4, 2, vec_dot_iq1_s_r4_q8_1, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}

void mul_mat_vec_iq5_k_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
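Note on the vec-dot added above: it exploits the affine form of the decoded weights. For each interleaved row, w_j = dl*g_j + ml, so the 16-value partial dot handled by one call (iqs selects which half of the 32-weight block) factors as dl * Σ g_j*q8_j + ml * Σ q8_j. The plain sum Σ q8_j ("minus") is computed once with four dp4a calls against 0x01010101 and reused for all four rows; in the kernel, dl already folds in the q8_1 block scale d8. A hedged scalar restatement of one row's contribution:

// Reference for one interleaved row and one call (16 q8 values);
// 'g' are the 16 codebook nibbles, dl and ml as computed in the kernel above.
static float vec_dot_iq1_s_r4_ref(const int8_t * g, const int8_t * q8, float dl, float ml) {
    int sum_gq = 0, sum_q = 0;
    for (int j = 0; j < 16; ++j) { sum_gq += g[j]*q8[j]; sum_q += q8[j]; }
    return dl*sum_gq + ml*sum_q;   // == sum_j (dl*g[j] + ml) * q8[j]
}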
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -90,3 +90,8 @@ void mul_mat_vec_iq5_ks_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);

void mul_mat_vec_iq1_s_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/mmq.cu
@@ -85,6 +85,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ1_S:
mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
break;
case GGML_TYPE_IQ1_S_R4:
mul_mat_q_case<GGML_TYPE_IQ1_S_R4>(ctx, args, stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
break;
@@ -150,6 +153,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_KS:
@@ -174,6 +178,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
if (int8_mma_available(cc)) {
return true;
}
if (type == GGML_TYPE_IQ1_S_R4) {
return false;
}

if (cc < MIN_CC_DP4A) {
return false;
80 changes: 80 additions & 0 deletions ggml/src/ggml-cuda/mmq.cuh
@@ -79,6 +79,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_IQ3_S:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_S_R4:
return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
@@ -186,6 +187,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ3_XXS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ3_S : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ1_S : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ1_S_R4: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0;
@@ -231,6 +233,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ3_XXS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ3_S : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ1_S : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ1_S_R4: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0;
@@ -318,6 +321,74 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s_r4(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + 2*WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

const int kbx = threadIdx.x / 4;
const int kqsx = threadIdx.x % 4;

int32_t grid32[2];

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;

if (need_check) {
i = min(i, i_max);
}
const int i4 = i/4;
const int ir = i%4;

const block_iq1_s_r4 * bxi = (const block_iq1_s_r4 *)(x + 4*i4*stride + 4*sizeof(half)) + kbx0 + kbx;

grid32[0] = iq1s_grid_gpu[bxi->qs[4*kqsx+ir] | (((bxi->qh[ir] >> 3*kqsx) & 7) << 8)];
grid32[1] = ((grid32[0] >> 4) & 0x0f0f0f0f) << 3;
grid32[0] = (grid32[0] & 0x0f0f0f0f) << 3;
const int shift = bxi->qh[ir] & 0x8000 ? 0x09090909 : 0x07070707;

#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kbx + 2*kqsx + 0] = __vsubss4(grid32[0], shift);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kbx + 2*kqsx + 1] = __vsubss4(grid32[1], shift);
#else
// TODO
//x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
#endif // INT8_MMA_AVAILABLE
}

const int blocks_per_tile_x_row = WARP_SIZE / 4;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + threadIdx.y * 4 + threadIdx.x / blocks_per_tile_x_row;

if (need_check) {
i = min(i, i_max);
}
const int i4 = i/4;
const int ir = i%4;

const half * dptr = (const half *)(x + 4*i4*stride);
const block_iq1_s_r4 * bxi = (const block_iq1_s_r4 *)(dptr + 4) + kbx0 + kbxd;

#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.125f * __half2float(dptr[ir]) * (((bxi->qh[ir] >> 11) & 14) + 1);
#else
// TODO
//x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
#endif // INT8_MMA_AVAILABLE
}
}

template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
@@ -3132,6 +3203,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S_R4> {
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s_r4<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
@@ -3656,6 +3735,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);

// -------------------------------------------------------------------------------------------------------------------------

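Note on load_tiles_iq1_s_r4 above: the loader keeps the tile in int8 by working in eighths. The codebook nibble is scaled by 8 (<< 3), then 9 or 7 is subtracted depending on the delta sign bit, and the missing factor 1/8 is folded into the tile scale. A worked check of the arithmetic, hedged on IQ1S_DELTA == 0.125f as used by the scalar dequantization:

// s = 3-bit block scale, g = codebook nibble, d = per-row fp16 scale
//   sign bit set   (qh & 0x8000): true weight w = d*(2*s+1)*(g - 1.125); tile value q = 8*g - 9 = 8*(g - 1.125)
//   sign bit clear               : true weight w = d*(2*s+1)*(g - 0.875); tile value q = 8*g - 7 = 8*(g - 0.875)
//   tile scale x_df = 0.125f * d * (((qh >> 11) & 14) + 1) = d*(2*s+1)/8
// Hence x_df*q reproduces w exactly while the tile stays int8 for the MMA kernels. The dp4a
// branches of the loader are still TODO, which is why ggml_cuda_should_use_mmq() returns false
// for IQ1_S_R4 when int8 MMA is not available.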
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/mmvq.cu
@@ -560,6 +560,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
case GGML_TYPE_IQ5_KS_R4:
mul_mat_vec_iq5_ks_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ1_S_R4:
mul_mat_vec_iq1_s_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
default:
GGML_ABORT("fatal error");
break;
@@ -679,6 +682,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ5_K_R4:
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ1_S_R4:
return true;
default:
return false;
5 changes: 5 additions & 0 deletions (new autogenerated MMQ template-instance file)
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);