@@ -861,61 +861,60 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+    mma_A A[ntx];
+    float dA[ntx][mma_C::ne/2];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
-#pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            const int k0 = k00 + k01;
-
-            A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
-        }
-
-#pragma unroll
+#pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+        mma_B B;
+        float dB[mma_C::ne/2];
+        B.load(y_qs + k01, MMQ_TILE_Y_K);
+#pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
-
-#pragma unroll
-            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-                const int k0 = k00 + k01;
-
-                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+            const int j = mma_C::get_j(l);
+            if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            } else {
+                dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            mma_B B;
-            float dB[mma_C::ne/2];
-
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+                dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+            }
+            mma_C C;
+            C.mma_K8(A[n], B);
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+            }
+        }
+#pragma unroll
+        for (int j0 = ntx*mma_C::J; j0 < mmq_x; j0 += ntx*mma_C::J) {
             B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-
-#pragma unroll
+#pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
-
-                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
-                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+                if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
                 } else {
                     dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
                 }
             }
-
-#pragma unroll
+#pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C;
-                C.mma_K8(A[n][k01/QI8_0], B);
-
-#pragma unroll
+                C.mma_K8(A[n], B);
+#pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
                 }
             }
         }
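This hunk interchanges the loop nest: `k01` becomes the outermost loop, so only one `mma_A` fragment per `n` has to stay live at a time (`A[ntx]` instead of `A[ntx][WARP_SIZE/QI8_0]`, and `dA` loses a dimension likewise), and the first output tile (`j0 == 0`) is peeled off so its `B` fragment and `dB` scales are ready before the `n` loop runs. The `ds_layout` branch also becomes `if constexpr`, which is valid because `ds_layout` is a compile-time template argument. A minimal, self-contained sketch of the same interchange idea with plain floats (illustrative names, not the kernel's types):

```cuda
// Sketch only: shows why hoisting the k-loop shrinks live register state.
// "a" stands in for an mma_A fragment; K and J for the k01/j0 trip counts.
__device__ void dot_k_outermost(const float * x, const float * y,
                                float * sum, int K, int J) {
    for (int k = 0; k < K; ++k) {      // k01 outermost, as in the new code
        const float a = x[k];          // one live A value per iteration
        for (int j = 0; j < J; ++j) {  // every output tile reuses it
            sum[j] += a * y[j*K + k];
        }
    }
}
```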
@@ -2701,6 +2700,64 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     }
 }
 
+// template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
+//     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+//
+// #ifdef INT8_MMA_AVAILABLE
+//     int   * x_qs = (int   *) x_tile;
+//     float * x_df = (float *) (x_qs + WARP_SIZE*2);
+// #else
+//     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+//     int   * x_qs = (int   *) x_tile;
+//     float * x_df = (float *) (x_qs + txs.qs);
+// #endif // INT8_MMA_AVAILABLE
+//
+//     const int kbx  = 0;           // threadIdx.x / QI4_XS
+//     const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+//
+// #pragma unroll
+//     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+//         int i = i0 + threadIdx.y;
+//
+//         if (need_check) {
+//             i = min(i, i_max);
+//         }
+//
+//         const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+//
+//         auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
+//         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+//         const int2 v = get_int_from_table_16(aux_q4, values);
+//         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+// #ifdef INT8_MMA_AVAILABLE
+//         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+//         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+// #else
+//         x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+//         x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+// #endif // INT8_MMA_AVAILABLE
+//     }
+//
+// #pragma unroll
+//     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+//         int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+//
+//         if (need_check) {
+//             i = min(i, i_max);
+//         }
+//
+//         const float * dptr = (const float *)(x + i*stride);
+//         const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+//         const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
+//
+// #ifdef INT8_MMA_AVAILABLE
+//         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+// #else
+//         x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+// #endif // INT8_MMA_AVAILABLE
+//     }
+// }
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
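The commented-out block above preserves the previous two-pass loader: one pass over `threadIdx.x` for the quants, and a second, differently strided pass for the scales, which read `bxi->scales` twice per block. The rewrite in the next hunk fuses the two passes: each thread takes sub-block `kqsx = threadIdx.x/4` of row `i0 + 4*threadIdx.y + threadIdx.x%4`, decodes that sub-block's scale byte once, and then unpacks its four packed q4 words in the `j` loop. For reference, a hedged sketch of the scale-byte decode the new code performs inline (the helper names here are made up for illustration; the expressions mirror the diff):

```cuda
#include <cstdint>

// Illustrative only: how one iq4_ks scale byte is split, per the hunk below.
// Bit 0 selects which 16-entry half of the iq4k_values LUT the sub-block
// uses; bits 1..7 carry the scale with a bias of 127.
__host__ __device__ inline int iq4_ks_scale(uint8_t sc) {
    return (sc & 254) - 127;   // matches: (bxi->scales[kqsx] & 254) - 127
}
__host__ __device__ inline int iq4_ks_lut_offset(uint8_t sc) {
    return (sc & 1) << 4;      // matches: (bxi->scales[kqsx] & 1) << 4
}
```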
@@ -2713,52 +2770,43 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     float * x_df = (float *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
 
-    const int kbx  = 0;           // threadIdx.x / QI4_XS
-    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+    const int kqsx = threadIdx.x/4;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
+        int i = i0 + 4*threadIdx.y + threadIdx.x%4;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+        const float * dptr = (const float *)(x + i*stride);
+        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+        const int ls = (bxi->scales[kqsx] & 254) - 127;
+        auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
 
-        auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
-        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4, values);
-        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#pragma unroll
+        for (int j = 0; j < 4; ++j) {
+            const int aux_q4 = get_int_b4(bxi->qs, 4*kqsx+j);
+            const int2 v = get_int_from_table_16(aux_q4, values);
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
-        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
 #endif // INT8_MMA_AVAILABLE
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
-
-        if (need_check) {
-            i = min(i, i_max);
         }
-
-        const float * dptr = (const float *)(x + i*stride);
-        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
-        const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
-
 #ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ls;
 #else
-        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+        x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ls;
 #endif // INT8_MMA_AVAILABLE
     }
+
 }
 
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks_r4(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
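One detail worth noting in the fused loop: judging from the stores above, `get_int_from_table_16` maps the eight nibbles of one packed q4 word through the selected LUT half and returns two ints, one built from the low nibbles and one from the high nibbles, which is why `v.x` and `v.y` land four ints apart (`+ 0` and `+ 4`) in the tile. A self-contained stand-in with the same shape (hypothetical, not the library routine):

```cuda
#include <cstdint>

// Hypothetical stand-in for get_int_from_table_16: expands the 8 nibbles of
// a packed q4 word through a 16-entry LUT into two ints of four int8 lanes,
// low nibbles in .x and high nibbles in .y, matching the "+ 0"/"+ 4" stores.
__device__ int2 expand_q4_via_lut(int q4, const int8_t * values) {
    int lo = 0, hi = 0;
    #pragma unroll
    for (int b = 0; b < 4; ++b) {
        lo |= (values[(q4 >> (8*b))     & 0xF] & 0xFF) << (8*b); // low nibble of byte b
        hi |= (values[(q4 >> (8*b + 4)) & 0xF] & 0xFF) << (8*b); // high nibble of byte b
    }
    return make_int2(lo, hi);
}
```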