@@ -67,26 +67,18 @@ static constexpr __device__ int get_mmq_y_device() {
 #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
 #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
 
-#define GET_MMQ_DP4A_TXS_BODY \
-    return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : \
-        type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : \
-        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : \
-        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : \
-        type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : \
-        type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : \
-        type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : \
-        type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : \
-        type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : \
-        type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : \
-        tile_x_sizes{0, 0, 0}
-
-static tile_x_sizes mmq_get_dp4a_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
-    GET_MMQ_DP4A_TXS_BODY;
-}
-
-template <int mmq_y>
-static constexpr __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes_device(ggml_type type) {
-    GET_MMQ_DP4A_TXS_BODY;
+static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
+    return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
+        type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
+        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
+        type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
+        type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
+        type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
+        type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
+        type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+        tile_x_sizes{0, 0, 0};
 }
 
 #define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
@@ -111,21 +103,18 @@ static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
 
-#define MMQ_MMA_GET_TILE_X_K_BODY \
-    return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : \
-        type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : \
-        type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : \
-        type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : \
-        type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : \
-        type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : \
-        type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : \
-        type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : \
-        type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : \
-        type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
-        0
-
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
-    MMQ_MMA_GET_TILE_X_K_BODY;
+    return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
+        type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
+        type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
+        type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
+        type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
+        type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
+        type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+        0;
 }
 
 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
@@ -154,7 +143,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE);
 #else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_0);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
@@ -204,7 +193,7 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_0);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + txs.qs;
     const int * y_qs = (const int *) y + 4;
@@ -317,7 +306,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
 #else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_1);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
@@ -367,7 +356,7 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_1);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     const int * x_qs = (const int *) x;
     const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int * y_qs = (const int *) y + 4;
@@ -479,7 +468,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_0);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
@@ -548,7 +537,7 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_0);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + txs.qs;
     const int * y_qs = (const int *) y + 4;
@@ -644,7 +633,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_1);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
@@ -711,7 +700,7 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_1);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     const int * x_qs = (const int *) x;
     const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int * y_qs = (const int *) y + 4;
@@ -808,7 +797,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_tile + WARP_SIZE);
 #else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q8_0);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
@@ -858,7 +847,7 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q8_0);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + txs.qs;
     const int * y_qs = (const int *) y + 4;
@@ -954,7 +943,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
 #else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q2_K);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
@@ -1013,7 +1002,7 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q2_K);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     const int * x_qs = (const int *) x;
     const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int * y_qs = (const int *) y + 4;
@@ -1135,7 +1124,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
     int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K);
 #else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q3_K);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int * x_sc = (int *) (x_df + txs.dm);
@@ -1233,7 +1222,7 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q3_K);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + txs.qs;
     const int * x_sc = (const int *) x_df + txs.dm;
@@ -1361,7 +1350,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
     int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K);
 #else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_K);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int * x_sc = (int *) (x_dm + txs.dm);
@@ -1437,7 +1426,7 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_K);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     const int * x_qs = (const int *) x;
     const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int * x_sc = (const int *) x_dm + txs.dm;
@@ -1578,7 +1567,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
     int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K);
 #else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_K);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int * x_sc = (int *) (x_dm + txs.dm);
@@ -1668,7 +1657,7 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_K);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     const int * x_qs = (const int *) x;
     const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int * x_sc = (const int *) x_dm + txs.dm;
@@ -1800,7 +1789,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
     int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
 #else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q6_K);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int * x_sc = (int *) (x_df + txs.dm);
@@ -1882,7 +1871,7 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q6_K);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + txs.qs;
     const int * x_sc = (const int *) x_df + txs.dm;
@@ -2422,7 +2411,7 @@ struct mmq_args {
 
 template <ggml_type type>
 static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
-    const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y);
+    const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
     const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);