Skip to content

Commit 5714f00

Browse files
simplify code, make functions constexpr
1 parent f3e5be0 commit 5714f00

File tree

2 files changed

+46
-57
lines changed

2 files changed

+46
-57
lines changed

ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
643643
static constexpr int qi = QI3_S;
644644
};
645645

646-
static int get_mmq_x_max_host(const int cc) {
646+
static constexpr int get_mmq_x_max_host(int cc) {
647647
#ifdef CUDA_USE_TENSOR_CORES
648648
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
649649
#else
@@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
652652
}
653653

654654
// Round rows to this value for --split-mode row:
655-
static int get_mmq_y_host(const int cc) {
655+
static constexpr int get_mmq_y_host(int cc) {
656656
return cc >= CC_VOLTA ? 128 : 64;
657657
}
658658

ggml-cuda/mmq.cuh

Lines changed: 44 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -67,26 +67,18 @@ static constexpr __device__ int get_mmq_y_device() {
6767
#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
6868
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
6969

70-
#define GET_MMQ_DP4A_TXS_BODY \
71-
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : \
72-
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : \
73-
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : \
74-
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : \
75-
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : \
76-
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : \
77-
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : \
78-
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : \
79-
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : \
80-
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : \
81-
tile_x_sizes{0, 0, 0}
82-
83-
static tile_x_sizes mmq_get_dp4a_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
84-
GET_MMQ_DP4A_TXS_BODY;
85-
}
86-
87-
template <int mmq_y>
88-
static constexpr __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes_device(ggml_type type) {
89-
GET_MMQ_DP4A_TXS_BODY;
70+
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
71+
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
72+
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
73+
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
74+
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
75+
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
76+
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
77+
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
78+
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
79+
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
80+
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
81+
tile_x_sizes{0, 0, 0};
9082
}
9183

9284
#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
@@ -111,21 +103,18 @@ static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
111103
static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
112104
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
113105

114-
#define MMQ_MMA_GET_TILE_X_K_BODY \
115-
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : \
116-
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : \
117-
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : \
118-
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : \
119-
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : \
120-
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : \
121-
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : \
122-
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : \
123-
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : \
124-
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
125-
0
126-
127106
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
128-
MMQ_MMA_GET_TILE_X_K_BODY;
107+
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
108+
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
109+
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
110+
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
111+
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
112+
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
113+
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
114+
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
115+
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
116+
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
117+
0;
129118
}
130119

131120
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
@@ -154,7 +143,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
154143
int * x_qs = (int *) x_tile;
155144
float * x_df = (float *) (x_qs + WARP_SIZE);
156145
#else
157-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_0);
146+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
158147
int * x_qs = (int *) x_tile;
159148
float * x_df = (float *) (x_qs + txs.qs);
160149
#endif // INT8_MMA_AVAILABLE
@@ -204,7 +193,7 @@ template <int mmq_x, int mmq_y, int nwarps>
204193
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
205194
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
206195

207-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_0);
196+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
208197
const int * x_qs = (const int *) x;
209198
const float * x_df = (const float *) x_qs + txs.qs;
210199
const int * y_qs = (const int *) y + 4;
@@ -317,7 +306,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
317306
int * x_qs = (int *) x_tile;
318307
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
319308
#else
320-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_1);
309+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
321310
int * x_qs = (int *) x_tile;
322311
half2 * x_dm = (half2 *) (x_qs + txs.qs);
323312
#endif // INT8_MMA_AVAILABLE
@@ -367,7 +356,7 @@ template <int mmq_x, int mmq_y, int nwarps>
367356
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
368357
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
369358

370-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_1);
359+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
371360
const int * x_qs = (const int *) x;
372361
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
373362
const int * y_qs = (const int *) y + 4;
@@ -479,7 +468,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
479468
int * x_qs = (int *) x_tile;
480469
float * x_df = (float *) (x_qs + WARP_SIZE*2);
481470
#else
482-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_0);
471+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
483472
int * x_qs = (int *) x_tile;
484473
float * x_df = (float *) (x_qs + txs.qs);
485474
#endif // INT8_MMA_AVAILABLE
@@ -548,7 +537,7 @@ template <int mmq_x, int mmq_y, int nwarps>
548537
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
549538
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
550539

551-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_0);
540+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
552541
const int * x_qs = (const int *) x;
553542
const float * x_df = (const float *) x_qs + txs.qs;
554543
const int * y_qs = (const int *) y + 4;
@@ -644,7 +633,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
644633
int * x_qs = (int *) x_tile;
645634
half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
646635
#else
647-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_1);
636+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
648637
int * x_qs = (int *) x_tile;
649638
half2 * x_dm = (half2 *) (x_qs + txs.qs);
650639
#endif // INT8_MMA_AVAILABLE
@@ -711,7 +700,7 @@ template <int mmq_x, int mmq_y, int nwarps>
711700
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
712701
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
713702

714-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_1);
703+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
715704
const int * x_qs = (const int *) x;
716705
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
717706
const int * y_qs = (const int *) y + 4;
@@ -808,7 +797,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
808797
int * x_qs = (int *) x_tile;
809798
float * x_df = (float *) (x_tile + WARP_SIZE);
810799
#else
811-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q8_0);
800+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
812801
int * x_qs = (int *) x_tile;
813802
float * x_df = (float *) (x_qs + txs.qs);
814803
#endif // INT8_MMA_AVAILABLE
@@ -858,7 +847,7 @@ template <int mmq_x, int mmq_y, int nwarps>
858847
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
859848
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
860849

861-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q8_0);
850+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
862851
const int * x_qs = (const int *) x;
863852
const float * x_df = (const float *) x_qs + txs.qs;
864853
const int * y_qs = (const int *) y + 4;
@@ -954,7 +943,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
954943
int * x_qs = (int *) x_tile;
955944
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
956945
#else
957-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q2_K);
946+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
958947
int * x_qs = (int *) x_tile;
959948
half2 * x_dm = (half2 *) (x_qs + txs.qs);
960949
#endif // INT8_MMA_AVAILABLE
@@ -1013,7 +1002,7 @@ template <int mmq_x, int mmq_y, int nwarps>
10131002
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
10141003
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
10151004

1016-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q2_K);
1005+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
10171006
const int * x_qs = (const int *) x;
10181007
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
10191008
const int * y_qs = (const int *) y + 4;
@@ -1135,7 +1124,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
11351124
float * x_df = (float *) (x_qs + WARP_SIZE*2);
11361125
int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K);
11371126
#else
1138-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q3_K);
1127+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
11391128
int * x_qs = (int *) x_tile;
11401129
float * x_df = (float *) (x_qs + txs.qs);
11411130
int * x_sc = (int *) (x_df + txs.dm);
@@ -1233,7 +1222,7 @@ template <int mmq_x, int mmq_y, int nwarps>
12331222
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
12341223
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
12351224

1236-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q3_K);
1225+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
12371226
const int * x_qs = (const int *) x;
12381227
const float * x_df = (const float *) x_qs + txs.qs;
12391228
const int * x_sc = (const int *) x_df + txs.dm;
@@ -1361,7 +1350,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
13611350
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
13621351
int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K);
13631352
#else
1364-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_K);
1353+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
13651354
int * x_qs = (int *) x_tile;
13661355
half2 * x_dm = (half2 *) (x_qs + txs.qs);
13671356
int * x_sc = (int *) (x_dm + txs.dm);
@@ -1437,7 +1426,7 @@ template <int mmq_x, int mmq_y, int nwarps>
14371426
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
14381427
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
14391428

1440-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_K);
1429+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
14411430
const int * x_qs = (const int *) x;
14421431
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
14431432
const int * x_sc = (const int *) x_dm + txs.dm;
@@ -1578,7 +1567,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
15781567
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
15791568
int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K);
15801569
#else
1581-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_K);
1570+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
15821571
int * x_qs = (int *) x_tile;
15831572
half2 * x_dm = (half2 *) (x_qs + txs.qs);
15841573
int * x_sc = (int *) (x_dm + txs.dm);
@@ -1668,7 +1657,7 @@ template <int mmq_x, int mmq_y, int nwarps>
16681657
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
16691658
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
16701659

1671-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_K);
1660+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
16721661
const int * x_qs = (const int *) x;
16731662
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
16741663
const int * x_sc = (const int *) x_dm + txs.dm;
@@ -1800,7 +1789,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
18001789
float * x_df = (float *) (x_qs + WARP_SIZE*2);
18011790
int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
18021791
#else
1803-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q6_K);
1792+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
18041793
int * x_qs = (int *) x_tile;
18051794
float * x_df = (float *) (x_qs + txs.qs);
18061795
int * x_sc = (int *) (x_df + txs.dm);
@@ -1882,7 +1871,7 @@ template <int mmq_x, int mmq_y, int nwarps>
18821871
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
18831872
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
18841873

1885-
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q6_K);
1874+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
18861875
const int * x_qs = (const int *) x;
18871876
const float * x_df = (const float *) x_qs + txs.qs;
18881877
const int * x_sc = (const int *) x_df + txs.dm;
@@ -2422,7 +2411,7 @@ struct mmq_args {
24222411

24232412
template<ggml_type type>
24242413
static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
2425-
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y);
2414+
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
24262415
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
24272416
const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
24282417
const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);

0 commit comments

Comments
 (0)