@@ -9,6 +9,18 @@ __embed_ggml-common.h__
 
 #include <metal_stdlib>
 
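+// toggle the Metal 4 tensor API (mpp::tensor_ops) path in kernel_mul_mm below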
+#define GGML_METAL_USE_METAL4
+
+#ifdef GGML_METAL_USE_METAL4
+#include <metal_stdlib>
+#include <metal_tensor>
+
+#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
+
+using namespace metal;
+using namespace mpp::tensor_ops;
+#endif
+
 using namespace metal;
 
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
@@ -8054,6 +8066,8 @@ kernel void kernel_mul_mm(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
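+    // view of the same shared memory, used to stage the Metal 4 result tile in the bounds-checked output path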
+    threadgroup float * sc = (threadgroup float *)(shmem);
+
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
 
@@ -8073,15 +8087,6 @@ kernel void kernel_mul_mm(
     const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0..63
     const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0..31
 
-    S0_8x8 ma[4];
-    S1_8x8 mb[2];
-
-    simdgroup_float8x8 mc[8];
-
-    for (short i = 0; i < 8; i++){
-        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
-    }
-
     const short il0 = (tiitg % NL0);
 
     short il = il0;
@@ -8102,7 +8107,28 @@ kernel void kernel_mul_mm(
         + args.nb11*(r1 + lr1)
         + args.nb10*iy);
 
+#ifndef GGML_METAL_USE_METAL4
+    S0_8x8 ma[4];
+    S1_8x8 mb[2];
+
+    simdgroup_float8x8 mc[8];
+
+    for (short i = 0; i < 8; i++){
+        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+#else
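+    // Metal 4 path: view the threadgroup A and B tiles as inline tensors consumed by matmul2d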
+    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
+    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK));
+
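+    // NR1 x NR0 output tile with reduction length NK; multiply_accumulate keeps accumulating across the K loop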
+    constexpr auto desc = matmul2d_descriptor(NR1, NR0, NK, false, true, false, matmul2d_descriptor::mode::multiply_accumulate);
+
+    matmul2d<desc, execution_simdgroups<4>> mm;
+
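+    // cooperative tensor that holds the accumulated result tile between iterations of the K loop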
+    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
+#endif
+
     for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
+#ifndef GGML_METAL_USE_METAL4
         // load data and store to threadgroup memory
         if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
             threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -8206,26 +8232,100 @@ kernel void kernel_mul_mm(
             lsma += 8*64;
             lsmb += 4*64;
         }
+#else
+        // load data and store to threadgroup memory
+        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // no need for dequantization
+            for (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                const short lx = i%8;
+                const short ly = (tiitg/NL0)%8;
+                // const short lx = (tiitg/NL0)%8;
+                // const short ly = i%8;
+
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
+            }
+        } else {
+            S0_4x4 temp_a;
+            dequantize_func(x, il, temp_a);
+
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            FOR_UNROLL (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                const short lx = i%8;
+                const short ly = (tiitg/NL0)%8;
+                // const short lx = (tiitg/NL0)%8;
+                // const short ly = i%8;
+
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
+            }
+        }
+
+        for (short i = 0; i < 8; ++i) {
+            const short sx = (tiitg%NL1);
+            const short sy = (tiitg/NL1)/8;
+
+            const short lx = i;
+            const short ly = (tiitg/NL1)%8;
+            // const short lx = (tiitg/NL1)%8;
+            // const short ly = i;
+
+            *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
+        }
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x = (il < 2) ? x + (2 + nl - 1)/nl : x;
+
+        y += NK;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
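+        // multiply the freshly loaded B and A tiles, accumulating into cT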
+        auto sA = tA.slice(0, 0);
+        auto sB = tB.slice(0, 0);
+
+        mm.run(sB, sA, cT);
+#endif
     }
 
     if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
         // if no bounds checks on the output are needed, we can directly write to device memory
+#ifdef GGML_METAL_USE_METAL4
+        device float * C = (device float *) dst +
+            r0 + \
+            r1 * args.ne0 + im*args.ne1*args.ne0;
+
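+        // store the accumulated tile directly through a tensor view over the destination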
+        auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
+        cT.store(tC);
+#else
         device float * C = (device float *) dst +
             (r0 + 32*(sgitg & 1)) + \
             (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
 
         for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0, 0, false);
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0, 0, false);
         }
+#endif
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
 
+#ifdef GGML_METAL_USE_METAL4
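+        // stage the Metal 4 result tile in shared memory so it can be written out with bounds checks below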
+        auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
+        cT.store(tC);
+#else
         for (short i = 0; i < 8; i++) {
             simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
         }
+#endif
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 