@@ -5439,8 +5439,8 @@ kernel void kernel_mul_mm(
54395439 ushort tiitg[[thread_index_in_threadgroup]],
54405440 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
54415441
5442- threadgroup T * sa = (threadgroup T *)(shmem);
5443- threadgroup float * sb = (threadgroup float *)(shmem + 4096 );
5442+ threadgroup T * sa = (threadgroup T *)(shmem);
5443+ threadgroup half * sb = (threadgroup half *)(shmem + 4096 );
54445444
54455445 const int r0 = tgpig.y ;
54465446 const int r1 = tgpig.x ;
@@ -5454,12 +5454,12 @@ kernel void kernel_mul_mm(
54545454 const short thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
54555455 const short thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
54565456
5457- simdgroup_T8x8 ma[4 ];
5458- simdgroup_float8x8 mb[2 ];
5459- simdgroup_float8x8 mc[8 ];
5457+ simdgroup_T8x8 ma[4 ];
5458+ simdgroup_half8x8 mb[2 ];
5459+ simdgroup_half8x8 mc[8 ];
54605460
54615461 for (short i = 0 ; i < 8 ; i++){
5462- mc[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
5462+ mc[i] = make_filled_simdgroup_matrix<half , 8 >(0 .h );
54635463 }
54645464
54655465 short il = (tiitg % THREAD_PER_ROW);
@@ -5493,7 +5493,7 @@ kernel void kernel_mul_mm(
54935493 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
54945494 }
54955495
5496- *(threadgroup float2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y );
5496+ *(threadgroup half2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (half2x4)( *((device float2x4 *)y) );
54975497
54985498 il = (il + 2 < nl) ? il + 2 : il % 2 ;
54995499 x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
@@ -5502,8 +5502,8 @@ kernel void kernel_mul_mm(
55025502 threadgroup_barrier (mem_flags::mem_threadgroup);
55035503
55045504 // load matrices from threadgroup memory and conduct outer products
5505- threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2 ));
5506- threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2 ));
5505+ threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2 ));
5506+ threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2 ));
55075507
55085508 #pragma unroll(4)
55095509 for (short ik = 0 ; ik < BLOCK_SIZE_K/8 ; ik++) {
@@ -5535,15 +5535,22 @@ kernel void kernel_mul_mm(
55355535 (BLOCK_SIZE_N * r1 + 16 *(sgitg >> 1 )) * args.ne0 + im*args.ne1 *args.ne0 ;
55365536
55375537 for (short i = 0 ; i < 8 ; i++) {
5538- simdgroup_store (mc[i], C + 8 * (i%4 ) + 8 * args.ne0 * (i/4 ), args.ne0 );
5538+ // cast to f32
5539+ simdgroup_float8x8 mc_f32 (1 .0f );
5540+ simdgroup_multiply (mc_f32, mc[i], mc_f32);
5541+ simdgroup_store (mc_f32, C + 8 * (i%4 ) + 8 * args.ne0 * (i/4 ), args.ne0 );
5542+ // simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
55395543 }
55405544 } else {
55415545 // block is smaller than 64x32, we should avoid writing data outside of the matrix
55425546 threadgroup_barrier (mem_flags::mem_threadgroup);
55435547 threadgroup float * temp_str = ((threadgroup float *) shmem) \
55445548 + 32 *(sgitg&1 ) + (16 *(sgitg >> 1 ))*BLOCK_SIZE_M;
55455549 for (short i = 0 ; i < 8 ; i++) {
5546- simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *BLOCK_SIZE_M*(i/4 ), BLOCK_SIZE_M);
5550+ simdgroup_float8x8 mc_f32 (1 .0f );
5551+ simdgroup_multiply (mc_f32, mc[i], mc_f32);
5552+ simdgroup_store (mc_f32, temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
5553+ // simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
55475554 }
55485555
55495556 threadgroup_barrier (mem_flags::mem_threadgroup);
0 commit comments