@@ -8285,40 +8285,6 @@ kernel void kernel_mul_mm(
82858285
82868286 *(threadgroup S1_2x4 *)(sb + 64 *ib + 8 *ly) = (S1_2x4)(*((device T1_2x4 *) y));
82878287 }
8288-
8289- il = (il + 2 < nl) ? il + 2 : il % 2 ;
8290- x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
8291-
8292- y += NK;
8293-
8294- // load matrices from threadgroup memory and conduct outer products
8295- threadgroup const S0 * lsma = (sa + 4 *64 *(sgitg%2 ));
8296- threadgroup const S1 * lsmb = (sb + 2 *64 *(sgitg/2 ));
8297-
8298- threadgroup_barrier (mem_flags::mem_threadgroup);
8299-
8300- FOR_UNROLL (short ik = 0 ; ik < NK/8 ; ik++) {
8301- simdgroup_barrier (mem_flags::mem_none);
8302-
8303- FOR_UNROLL (short i = 0 ; i < 4 ; i++) {
8304- simdgroup_load (ma[i], lsma + 64 *i, 8 , 0 , false );
8305- }
8306-
8307- simdgroup_barrier (mem_flags::mem_none);
8308-
8309- FOR_UNROLL (short i = 0 ; i < 2 ; i++) {
8310- simdgroup_load (mb[i], lsmb + 64 *i, 8 , 0 , false );
8311- }
8312-
8313- simdgroup_barrier (mem_flags::mem_none);
8314-
8315- FOR_UNROLL (short i = 0 ; i < 8 ; i++){
8316- simdgroup_multiply_accumulate (mc[i], mb[i/4 ], ma[i%4 ], mc[i]);
8317- }
8318-
8319- lsma += 8 *64 ;
8320- lsmb += 4 *64 ;
8321- }
83228288#else
83238289 // load data and store to threadgroup memory
83248290 if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
@@ -8378,6 +8344,7 @@ kernel void kernel_mul_mm(
83788344
83798345 *(threadgroup S1_2x4 *)(sb + NK*(8 *sy + ly) + 8 *sx) = (S1_2x4)(*((device T1_2x4 *) y));
83808346 }
8347+ #endif
83818348
83828349 il = (il + 2 < nl) ? il + 2 : il % 2 ;
83838350 x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
@@ -8386,6 +8353,34 @@ kernel void kernel_mul_mm(
83868353
83878354 threadgroup_barrier (mem_flags::mem_threadgroup);
83888355
8356+ #ifndef GGML_METAL_HAS_TENSOR
8357+ // load matrices from threadgroup memory and conduct outer products
8358+ threadgroup const S0 * lsma = (sa + 4 *64 *(sgitg%2 ));
8359+ threadgroup const S1 * lsmb = (sb + 2 *64 *(sgitg/2 ));
8360+
8361+ FOR_UNROLL (short ik = 0 ; ik < NK/8 ; ik++) {
8362+ simdgroup_barrier (mem_flags::mem_none);
8363+
8364+ FOR_UNROLL (short i = 0 ; i < 4 ; i++) {
8365+ simdgroup_load (ma[i], lsma + 64 *i, 8 , 0 , false );
8366+ }
8367+
8368+ simdgroup_barrier (mem_flags::mem_none);
8369+
8370+ FOR_UNROLL (short i = 0 ; i < 2 ; i++) {
8371+ simdgroup_load (mb[i], lsmb + 64 *i, 8 , 0 , false );
8372+ }
8373+
8374+ simdgroup_barrier (mem_flags::mem_none);
8375+
8376+ FOR_UNROLL (short i = 0 ; i < 8 ; i++){
8377+ simdgroup_multiply_accumulate (mc[i], mb[i/4 ], ma[i%4 ], mc[i]);
8378+ }
8379+
8380+ lsma += 8 *64 ;
8381+ lsmb += 4 *64 ;
8382+ }
8383+ #else
83898384 auto sA = tA.slice (0 , 0 );
83908385 auto sB = tB.slice (0 , 0 );
83918386
0 commit comments