@@ -12,13 +12,9 @@ __embed_ggml-common.h__
 #define GGML_METAL_USE_METAL4
 
 #ifdef GGML_METAL_USE_METAL4
-#include <metal_stdlib>
 #include <metal_tensor>
 
 #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
-
-using namespace metal;
-using namespace mpp::tensor_ops;
 #endif
 
 using namespace metal;
@@ -1754,7 +1750,7 @@ kernel void kernel_op_sum_f32(
 
     float sumf = 0;
 
-    for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
+    for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
         sumf += src0[i0];
     }
 
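The loop index here switches to `uint64_t`, matching the unsigned `tpitg.x`/`ntg.x` thread built-ins and (presumably) the unsigned `args.np` element count, so the `i0 < args.np` comparison no longer mixes signed and unsigned operands. A minimal standalone sketch of the same strided-reduction pattern, with hypothetical buffer bindings in place of the kernel's real argument struct:

```metal
#include <metal_stdlib>
using namespace metal;

// Sketch only: np stands in for args.np, and the threadgroup reduction that
// kernel_op_sum_f32 performs afterwards is omitted.
kernel void sum_sketch(device const float * src0 [[buffer(0)]],
                       constant uint64_t  & np   [[buffer(1)]],
                       uint3 tpitg [[thread_position_in_threadgroup]],
                       uint3 ntg   [[threads_per_threadgroup]]) {
    float sumf = 0;

    // tpitg.x and ntg.x are unsigned, so an unsigned 64-bit index keeps the
    // comparison against np free of sign-conversion surprises.
    for (uint64_t i0 = tpitg.x; i0 < np; i0 += ntg.x) {
        sumf += src0[i0];
    }
    (void) sumf; // reduction and write-back elided
}
```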
@@ -5366,6 +5362,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_at
 
 #undef FA_TYPES
 #undef FA_TYPES_BF
+#undef FA_TYPES_F32
 
 constant bool FC_flash_attn_ext_vec_has_mask  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
 constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
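`FA_TYPES_F32` joins the existing `#undef`s at the end of this instantiation block. The file stamps out kernel variants by defining a short-lived type-list macro, instantiating templates under per-variant `host_name`s, and undefining the macro before the next kernel family reuses the name; leaving `FA_TYPES_F32` defined would let a stale type list leak into a later block. A toy sketch of the pattern (`kernel_scale` and the single-type lists are placeholders, not the real flash-attention type lists):

```metal
#include <metal_stdlib>
using namespace metal;

// The signature is type-erased (device char *) so that every instantiation
// shares one function type, mirroring how the real kernels are declared.
template<typename T>
kernel void kernel_scale(device char * x [[buffer(0)]],
                         uint i [[thread_position_in_grid]]) {
    device T * xt = (device T *) x;
    xt[i] *= T(2);
}

typedef decltype(kernel_scale<float>) kernel_scale_t;

#define FA_TYPES     half
#define FA_TYPES_F32 float

template [[host_name("kernel_scale_f16")]] kernel kernel_scale_t kernel_scale<FA_TYPES>;
template [[host_name("kernel_scale_f32")]] kernel kernel_scale_t kernel_scale<FA_TYPES_F32>;

// Undefine both so the next family can redefine the same names.
#undef FA_TYPES
#undef FA_TYPES_F32
```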
@@ -5987,6 +5984,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flas
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
 
 #undef FA_TYPES
+#undef FA_TYPES_F32
 
 constant int32_t FC_flash_attn_ext_vec_reduce_DV  [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
 constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];
@@ -8120,9 +8118,9 @@ kernel void kernel_mul_mm(
         auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
         auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK));
 
-        constexpr auto desc = matmul2d_descriptor(NR1, NR0, NK, false, true, false, matmul2d_descriptor::mode::multiply_accumulate);
+        constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);
 
-        matmul2d<desc, execution_simdgroups<4>> mm;
+        mpp::tensor_ops::matmul2d<desc, execution_simdgroups<4>> mm;
 
         auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
 #endif
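With the file-wide `using namespace mpp::tensor_ops;` gone (first hunk), the Metal 4 matmul types are spelled out in full at the call site; only `using namespace metal;` remains in scope, which, judging by the diff, is what still resolves the unqualified `execution_simdgroups<4>`. If the long qualifications become noisy, a scoped namespace alias would be an equivalent spelling; a sketch with placeholder tile sizes in place of the kernel's NR0/NR1/NK, inside the same `#ifdef GGML_METAL_USE_METAL4` region:

```metal
namespace to = mpp::tensor_ops; // local alias, keeps the global scope clean

constexpr int NR0 = 64, NR1 = 32, NK = 32; // placeholder tile shape

constexpr auto desc = to::matmul2d_descriptor(
    NR1, NR0, NK,        // problem shape, as in the code above
    false, true, false,  // transpose/option flags, unchanged from the hunk
    to::matmul2d_descriptor::mode::multiply_accumulate);

to::matmul2d<desc, execution_simdgroups<4>> mm; // cooperative over 4 simdgroups
```

Either way the effect of the change is the same: MPP symbols no longer shadow or collide with `metal::` names across the rest of this large shader file.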
@@ -8268,16 +8266,28 @@ kernel void kernel_mul_mm(
             }
         }
 
-        for (short i = 0; i < 8; ++i) {
+        if (FC_mul_mm_bc_inp) {
+            for (short i = 0; i < 8; ++i) {
+                const short sx = (tiitg%NL1);
+                const short sy = (tiitg/NL1)/8;
+
+                const short lx = i;
+                const short ly = (tiitg/NL1)%8;
+              //const short lx = (tiitg/NL1)%8;
+              //const short ly = i;
+
+                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
+            }
+        } else {
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short lx = i;
+          //const short lx = i;
             const short ly = (tiitg/NL1)%8;
           //const short lx = (tiitg/NL1)%8;
           //const short ly = i;
 
-            *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
+            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
         }
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
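The staging of the src1 fragment into the shared tile `sb` now branches on the `FC_mul_mm_bc_inp` function constant: when the input needs bounds checking, the eight values are copied one scalar at a time under the `loop_k + iy + i < args.ne00` guard with zero padding, and otherwise a single 8-element vectorized load/convert/store (`T1_2x4` → `S1_2x4`) moves the whole fragment. Because `FC_mul_mm_bc_inp` is a function constant, the branch is resolved when the pipeline is specialized, so the fast path carries no per-thread cost for the check. A standalone sketch of the two strategies, with hypothetical fixed types (`float` source, `half` tile) in place of the kernel's `T1`/`S1` template parameters; the kernel's `sx`/`sy`/`lx`/`ly` tile swizzling is dropped here for clarity:

```metal
#include <metal_stdlib>
using namespace metal;

// Sketch only: needs_check stands in for the FC_mul_mm_bc_inp function
// constant, and remaining for the args.ne00-based bound.
void stage_row(threadgroup half   * dst,        // destination inside the shared tile
               device const float * src,        // 8-wide source fragment
               int                  remaining,   // valid elements left in the row
               bool                 needs_check) {
    if (needs_check) {
        // Broadcast / tail case: scalar copy with zero padding past the end.
        for (short i = 0; i < 8; ++i) {
            dst[i] = i < remaining ? (half) src[i] : half(0);
        }
    } else {
        // Fast path: one vectorized 8-element (2x4) load, convert and store;
        // valid only when the full fragment is known to be in bounds.
        *(threadgroup half2x4 *) dst = (half2x4)(*(device const float2x4 *) src);
    }
}
```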