Skip to content

Commit 46f2e5d

Browse files
ikawrakow (Iwan Kawrakow)
and authored
Minor CUDA PP speed improvement (#567)
* Slightly better q8_0_q8_1 kernel and iq4_ks tile loading * Minor --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent b256675 commit 46f2e5d

File tree

1 file changed

+115
-67
lines changed

1 file changed

+115
-67
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 115 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -861,61 +861,60 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
861861
const float * y_df = (const float *) y;
862862
const half2 * y_ds = (const half2 *) y;
863863

864-
mma_A A[ntx][WARP_SIZE/QI8_0];
865-
float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
864+
mma_A A[ntx];
865+
float dA[ntx][mma_C::ne/2];
866866

867867
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
868868

869-
#pragma unroll
870-
for (int n = 0; n < ntx; ++n) {
871-
#pragma unroll
872-
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
873-
const int k0 = k00 + k01;
874-
875-
A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
876-
}
877-
878-
#pragma unroll
869+
#pragma unroll
870+
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
871+
const int k0 = k00 + k01;
872+
mma_B B;
873+
float dB[mma_C::ne/2];
874+
B.load(y_qs + k01, MMQ_TILE_Y_K);
875+
#pragma unroll
879876
for (int l = 0; l < mma_C::ne/2; ++l) {
880-
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
881-
882-
#pragma unroll
883-
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
884-
const int k0 = k00 + k01;
885-
886-
dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
877+
const int j = mma_C::get_j(l);
878+
if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
879+
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
880+
} else {
881+
dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
887882
}
888883
}
889-
}
890-
891-
#pragma unroll
892-
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
893-
#pragma unroll
894-
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
895-
mma_B B;
896-
float dB[mma_C::ne/2];
897-
884+
#pragma unroll
885+
for (int n = 0; n < ntx; ++n) {
886+
A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
887+
#pragma unroll
888+
for (int l = 0; l < mma_C::ne/2; ++l) {
889+
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
890+
dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
891+
}
892+
mma_C C;
893+
C.mma_K8(A[n], B);
894+
#pragma unroll
895+
for (int l = 0; l < mma_C::ne; ++l) {
896+
sum[(n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
897+
}
898+
}
899+
#pragma unroll
900+
for (int j0 = ntx*mma_C::J; j0 < mmq_x; j0 += ntx*mma_C::J) {
898901
B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
899-
900-
#pragma unroll
902+
#pragma unroll
901903
for (int l = 0; l < mma_C::ne/2; ++l) {
902904
const int j = j0 + mma_C::get_j(l);
903-
904-
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
905-
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
905+
if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
906+
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
906907
} else {
907908
dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
908909
}
909910
}
910-
911-
#pragma unroll
911+
#pragma unroll
912912
for (int n = 0; n < ntx; ++n) {
913913
mma_C C;
914-
C.mma_K8(A[n][k01/QI8_0], B);
915-
916-
#pragma unroll
914+
C.mma_K8(A[n], B);
915+
#pragma unroll
917916
for (int l = 0; l < mma_C::ne; ++l) {
918-
sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
917+
sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
919918
}
920919
}
921920
}
@@ -2701,6 +2700,64 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
27012700
}
27022701
}
27032702

2703+
//template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
2704+
// const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2705+
//
2706+
//#ifdef INT8_MMA_AVAILABLE
2707+
// int * x_qs = (int *) x_tile;
2708+
// float * x_df = (float *) (x_qs + WARP_SIZE*2);
2709+
//#else
2710+
// constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2711+
// int * x_qs = (int *) x_tile;
2712+
// float * x_df = (float *) (x_qs + txs.qs);
2713+
//#endif // INT8_MMA_AVAILABLE
2714+
//
2715+
// const int kbx = 0; // threadIdx.x / QI4_XS
2716+
// const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
2717+
//
2718+
//#pragma unroll
2719+
// for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
2720+
// int i = i0 + threadIdx.y;
2721+
//
2722+
// if (need_check) {
2723+
// i = min(i, i_max);
2724+
// }
2725+
//
2726+
// const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
2727+
//
2728+
// auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
2729+
// const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2730+
// const int2 v = get_int_from_table_16(aux_q4, values);
2731+
// const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2732+
//#ifdef INT8_MMA_AVAILABLE
2733+
// x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2734+
// x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2735+
//#else
2736+
// x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2737+
// x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
2738+
//#endif // INT8_MMA_AVAILABLE
2739+
// }
2740+
//
2741+
//#pragma unroll
2742+
// for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
2743+
// int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
2744+
//
2745+
// if (need_check) {
2746+
// i = min(i, i_max);
2747+
// }
2748+
//
2749+
// const float * dptr = (const float *)(x + i*stride);
2750+
// const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
2751+
// const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
2752+
//
2753+
//#ifdef INT8_MMA_AVAILABLE
2754+
// x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
2755+
//#else
2756+
// x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
2757+
//#endif // INT8_MMA_AVAILABLE
2758+
// }
2759+
//}
2760+
27042761
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
27052762
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
27062763

@@ -2713,52 +2770,43 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
27132770
float * x_df = (float *) (x_qs + txs.qs);
27142771
#endif // INT8_MMA_AVAILABLE
27152772

2716-
const int kbx = 0; // threadIdx.x / QI4_XS
2717-
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
2773+
const int kqsx = threadIdx.x / 4;
27182774

27192775
#pragma unroll
2720-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
2721-
int i = i0 + threadIdx.y;
2776+
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
2777+
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
27222778

27232779
if (need_check) {
27242780
i = min(i, i_max);
27252781
}
27262782

2727-
const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
2783+
const float * dptr = (const float *)(x + i*stride);
2784+
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
2785+
const int ls = (bxi->scales[kqsx] & 254) - 127;
2786+
auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
27282787

2729-
auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
2730-
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2731-
const int2 v = get_int_from_table_16(aux_q4, values);
2732-
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2788+
#pragma unroll
2789+
for (int j = 0; j < 4; ++j) {
2790+
const int aux_q4 = get_int_b4(bxi->qs, 4*kqsx+j);
2791+
const int2 v = get_int_from_table_16(aux_q4, values);
27332792
#ifdef INT8_MMA_AVAILABLE
2734-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2735-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2793+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
2794+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
27362795
#else
2737-
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2738-
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
2796+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
2797+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
27392798
#endif // INT8_MMA_AVAILABLE
2740-
}
2741-
2742-
#pragma unroll
2743-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
2744-
int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
2745-
2746-
if (need_check) {
2747-
i = min(i, i_max);
27482799
}
2749-
2750-
const float * dptr = (const float *)(x + i*stride);
2751-
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
2752-
const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
2753-
27542800
#ifdef INT8_MMA_AVAILABLE
2755-
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
2801+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ls;
27562802
#else
2757-
x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
2803+
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ls;
27582804
#endif // INT8_MMA_AVAILABLE
27592805
}
2806+
27602807
}
27612808

2809+
27622810
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks_r4(
27632811
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
27642812

0 commit comments

Comments (0)