182 changes: 115 additions & 67 deletions ggml/src/ggml-cuda/mmq.cuh
@@ -861,61 +861,60 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+    mma_A A[ntx];
+    float dA[ntx][mma_C::ne/2];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
 #pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            const int k0 = k00 + k01;
-
-            A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
-        }
-
-#pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
-
-#pragma unroll
-            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-                const int k0 = k00 + k01;
-
-                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
-            }
-        }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            mma_B B;
-            float dB[mma_C::ne/2];
-
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
-
-                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
-                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
-                } else {
-                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
-                }
-            }
-
-#pragma unroll
-            for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma_K8(A[n][k01/QI8_0], B);
-
-#pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
-                }
-            }
-        }
-    }
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+        mma_B B;
+        float dB[mma_C::ne/2];
+        B.load(y_qs + k01, MMQ_TILE_Y_K);
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = mma_C::get_j(l);
+            if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            } else {
+                dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+                dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+            }
+            mma_C C;
+            C.mma_K8(A[n], B);
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+            }
+        }
+#pragma unroll
+        for (int j0 = ntx*mma_C::J; j0 < mmq_x; j0 += ntx*mma_C::J) {
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+                if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+                } else {
+                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                }
+            }
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n], B);
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+                }
+            }
+        }
+    }
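
Note on the restructuring above: the old kernel kept every K-slice of A and its scales resident across the whole j0 loop (ntx * WARP_SIZE/QI8_0 mma_A tiles at once), while the new kernel makes k01 the outermost loop, so only one A tile per n is live at a time and the first column block (j0 == 0) reuses the freshly loaded B directly. Below is a toy, host-side C++ sketch of that lifetime difference; the tile type and sizes are dummies for illustration, not the real mma_A.

// Toy sketch: how hoisting the k01 loop shrinks the set of simultaneously
// live A tiles. Compile and run as ordinary C++; not kernel code.
#include <cstdio>

constexpr int WARP_SIZE = 32, QI8_0 = 8, ntx = 2;
struct mma_A_stub { int regs[4]; }; // stand-in for the real mma_A tile

int main() {
    // Old shape: all WARP_SIZE/QI8_0 K-slices per n resident at the same time.
    {
        mma_A_stub A[ntx][WARP_SIZE/QI8_0] = {};
        printf("old: %zu ints of A live at once\n", sizeof(A)/sizeof(int));
    }
    // New shape: k01 outermost, so only one A tile per n is resident per step.
    for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
        mma_A_stub A[ntx] = {};
        printf("new (k01=%2d): %zu ints of A live\n", k01, sizeof(A)/sizeof(int));
    }
    return 0;
}

The trade is a few redundant A loads per k01 step in exchange for lower register pressure across the j0 loop.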
@@ -2701,6 +2700,64 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__
     }
 }
 
+//template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
+//    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+//
+//#ifdef INT8_MMA_AVAILABLE
+//    int * x_qs = (int *) x_tile;
+//    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+//#else
+//    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+//    int * x_qs = (int *) x_tile;
+//    float * x_df = (float *) (x_qs + txs.qs);
+//#endif // INT8_MMA_AVAILABLE
+//
+//    const int kbx = 0;           // threadIdx.x / QI4_XS
+//    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+//
+//#pragma unroll
+//    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+//        int i = i0 + threadIdx.y;
+//
+//        if (need_check) {
+//            i = min(i, i_max);
+//        }
+//
+//        const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+//
+//        auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
+//        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+//        const int2 v = get_int_from_table_16(aux_q4, values);
+//        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+//#ifdef INT8_MMA_AVAILABLE
+//        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+//        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+//#else
+//        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+//        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+//#endif // INT8_MMA_AVAILABLE
+//    }
+//
+//#pragma unroll
+//    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+//        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+//
+//        if (need_check) {
+//            i = min(i, i_max);
+//        }
+//
+//        const float * dptr = (const float *)(x + i*stride);
+//        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+//        const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
+//
+//#ifdef INT8_MMA_AVAILABLE
+//        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+//#else
+//        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+//#endif // INT8_MMA_AVAILABLE
+//    }
+//}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
@@ -2713,52 +2770,43 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__
     float * x_df = (float *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
 
-    const int kbx = 0;           // threadIdx.x / QI4_XS
-    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+    const int kqsx = threadIdx.x / 4;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
+        int i = i0 + 4*threadIdx.y + threadIdx.x%4;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+        const float * dptr = (const float *)(x + i*stride);
+        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+        const int ls = (bxi->scales[kqsx] & 254) - 127;
+        auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
 
-        auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
-        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4, values);
-        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#pragma unroll
+        for (int j = 0; j < 4; ++j) {
+            const int aux_q4 = get_int_b4(bxi->qs, 4*kqsx+j);
+            const int2 v = get_int_from_table_16(aux_q4, values);
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
-        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
 #endif // INT8_MMA_AVAILABLE
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const float * dptr = (const float *)(x + i*stride);
-        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
-        const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
-
+        }
 #ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ls;
 #else
-        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+        x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ls;
 #endif // INT8_MMA_AVAILABLE
     }
-
 }
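
For reference, here is the decode the new loader performs per 32-weight sub-block, written as a host-side sketch. Assumptions not shown in this diff: QK_K == 256, block_iq4_ks laid out as an 8-byte scales array followed by QK_K/2 bytes of packed 4-bit quants, the per-row float scale stored immediately before the first block (which is what dptr[0] reads), and iq4k_values being the repo's 2x16-entry signed lookup table, passed in here as a parameter rather than hardcoded.

// Host-side reference sketch of the iq4_ks sub-block decode; layout assumed
// from how the kernel above indexes bxi->scales and bxi->qs.
#include <cstdint>

constexpr int QK_K = 256;

struct block_iq4_ks {
    uint8_t scales[QK_K/32]; // per sub-block: bit 0 selects the value table,
                             // bits 1..7 hold the scale biased by 127
    uint8_t qs[QK_K/2];      // 4-bit quants, two per byte
};

// Decode the 32 weights of sub-block ib (0..7). d is the per-row float scale
// the kernel reads via dptr[0]; values32 is the 32-entry iq4k_values table.
static void dequant_iq4_ks_subblock(const block_iq4_ks * blk, float d, int ib,
                                    const int8_t * values32, float * out) {
    const float    ls     = d * ((blk->scales[ib] & 254) - 127);
    const int8_t * values = values32 + ((blk->scales[ib] & 1) << 4);
    const uint8_t * q     = blk->qs + 16*ib;
    for (int j = 0; j < 16; ++j) {
        out[j +  0] = ls * values[q[j] & 0xf]; // low nibbles -> first 16 weights
        out[j + 16] = ls * values[q[j] >> 4];  // high nibbles -> last 16 weights
    }
}

The kernel does the same thing four bytes at a time: get_int_b4 pulls one int of packed nibbles and get_int_from_table_16 maps its low and high nibbles to the two ints stored at offsets 8*kqsx + j and 8*kqsx + j + 4.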


template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks_r4(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

Expand Down