@@ -39,6 +39,8 @@ typedef struct {
3939 int8_t qs[QK8_0]; // quants
4040} block_q8_0;
4141
42+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
43+
4244// general-purpose kernel for addition of two tensors
4345// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
4446// cons: not very efficient
@@ -207,54 +209,55 @@ kernel void kernel_soft_max(
207209 lmax = MAX (lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0 .0f ));
208210 }
209211
210- float max = simd_max (lmax);
211- if (tiisg == 0 ) {
212- buf[sgitg] = max;
213- }
212+ // find the max value in the block
213+ float max_val = simd_max (lmax);
214+ if (ntg > N_SIMDWIDTH) {
215+ if (sgitg == 0 ) {
216+ buf[tiisg] = -INFINITY;
217+ }
214218
215- threadgroup_barrier (mem_flags::mem_threadgroup);
219+ threadgroup_barrier (mem_flags::mem_threadgroup);
216220
217- // broadcast, simd group number is ntg / 32
218- for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
219- if (tpitg < i) {
220- buf[tpitg] = MAX (buf[tpitg], buf[tpitg + i]);
221- }
222- }
221+ if (tiisg == 0 ) {
222+ buf[sgitg] = max_val;
223+ }
223224
224- threadgroup_barrier (mem_flags::mem_threadgroup);
225+ threadgroup_barrier (mem_flags::mem_threadgroup);
225226
226- max = buf[0 ];
227+ max_val = buf[tiisg];
228+ max_val = simd_max (max_val);
229+ }
227230
228231 // parallel sum
229232 float lsum = 0 .0f ;
230233 for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
231- const float exp_psrc0 = exp ((psrc0[i00]*scale + (pmask ? pmask[i00] : 0 .0f )) - max );
234+ const float exp_psrc0 = exp ((psrc0[i00]*scale + (pmask ? pmask[i00] : 0 .0f )) - max_val );
232235 lsum += exp_psrc0;
233- // Remember the result of exp here. exp is expensive, so we really do not
234- // wish to compute it twice.
235236 pdst[i00] = exp_psrc0;
236237 }
237238
238239 float sum = simd_sum (lsum);
239- if (tiisg == 0 ) {
240- buf[sgitg] = sum;
241- }
240+ if (ntg > N_SIMDWIDTH) {
241+ if (sgitg == 0 ) {
242+ buf[tiisg] = 0 .0f ;
243+ }
242244
243- threadgroup_barrier (mem_flags::mem_threadgroup);
245+ threadgroup_barrier (mem_flags::mem_threadgroup);
244246
245- // broadcast, simd group number is ntg / 32
246- for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
247- if (tpitg < i) {
248- buf[tpitg] += buf[tpitg + i];
249- }
250- }
247+ if (tiisg == 0 ) {
248+ buf[sgitg] = sum;
249+ }
251250
252- threadgroup_barrier (mem_flags::mem_threadgroup);
251+ threadgroup_barrier (mem_flags::mem_threadgroup);
252+
253+ sum = buf[tiisg];
254+ sum = simd_sum (sum);
255+ }
253256
254- sum = buf[ 0 ] ;
257+ const float inv_sum = 1 . 0f /sum ;
255258
256259 for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
257- pdst[i00] /= sum ;
260+ pdst[i00] *= inv_sum ;
258261 }
259262}
260263
@@ -288,53 +291,56 @@ kernel void kernel_soft_max_4(
288291 }
289292
290293 const float lmax = MAX (MAX (lmax4[0 ], lmax4[1 ]), MAX (lmax4[2 ], lmax4[3 ]));
291- float max = simd_max (lmax);
292- if (tiisg == 0 ) {
293- buf[sgitg] = max;
294- }
295294
296- threadgroup_barrier (mem_flags::mem_threadgroup);
295+ float max_val = simd_max (lmax);
296+ if (ntg > N_SIMDWIDTH) {
297+ if (sgitg == 0 ) {
298+ buf[tiisg] = -INFINITY;
299+ }
297300
298- // broadcast, simd group number is ntg / 32
299- for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
300- if (tpitg < i) {
301- buf[tpitg] = MAX (buf[tpitg], buf[tpitg + i]);
302- }
303- }
301+ threadgroup_barrier (mem_flags::mem_threadgroup);
304302
305- threadgroup_barrier (mem_flags::mem_threadgroup);
303+ if (tiisg == 0 ) {
304+ buf[sgitg] = max_val;
305+ }
306306
307- max = buf[0 ];
307+ threadgroup_barrier (mem_flags::mem_threadgroup);
308+
309+ max_val = buf[tiisg];
310+ max_val = simd_max (max_val);
311+ }
308312
309313 // parallel sum
310314 float4 lsum4 = 0 .0f ;
311315 for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
312- const float4 exp_psrc4 = exp ((psrc4[i00]*scale + (pmask ? pmask[i00] : 0 .0f )) - max );
316+ const float4 exp_psrc4 = exp ((psrc4[i00]*scale + (pmask ? pmask[i00] : 0 .0f )) - max_val );
313317 lsum4 += exp_psrc4;
314318 pdst4[i00] = exp_psrc4;
315319 }
316320
317321 const float lsum = lsum4[0 ] + lsum4[1 ] + lsum4[2 ] + lsum4[3 ];
318322 float sum = simd_sum (lsum);
319- if (tiisg == 0 ) {
320- buf[sgitg] = sum;
321- }
323+ if (ntg > N_SIMDWIDTH) {
324+ if (sgitg == 0 ) {
325+ buf[tiisg] = 0 .0f ;
326+ }
322327
323- threadgroup_barrier (mem_flags::mem_threadgroup);
328+ threadgroup_barrier (mem_flags::mem_threadgroup);
324329
325- // broadcast, simd group number is ntg / 32
326- for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
327- if (tpitg < i) {
328- buf[tpitg] += buf[tpitg + i];
329- }
330- }
330+ if (tiisg == 0 ) {
331+ buf[sgitg] = sum;
332+ }
331333
332- threadgroup_barrier (mem_flags::mem_threadgroup);
334+ threadgroup_barrier (mem_flags::mem_threadgroup);
335+
336+ sum = buf[tiisg];
337+ sum = simd_sum (sum);
338+ }
333339
334- sum = buf[ 0 ] ;
340+ const float inv_sum = 1 . 0f /sum ;
335341
336342 for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
337- pdst4[i00] /= sum ;
343+ pdst4[i00] *= inv_sum ;
338344 }
339345}
340346
@@ -582,7 +588,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
582588// putting them in the kernel cause a significant performance penalty
583589#define N_DST 4 // each SIMD group works on 4 rows
584590#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
585- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
586591// Note: This is a template, but strictly speaking it only applies to
587592// quantizations where the block size is 32. It also does not
588593// giard against the number of rows not being divisible by
0 commit comments