Skip to content

Commit 2e48e97

Browse files
committed
metal : initial Metal4 support
1 parent c009ffe commit 2e48e97

File tree

1 file changed

+110
-10
lines changed

1 file changed

+110
-10
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 110 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@ __embed_ggml-common.h__
99

1010
#include <metal_stdlib>
1111

12+
#define GGML_METAL_USE_METAL4
13+
14+
#ifdef GGML_METAL_USE_METAL4
15+
#include <metal_stdlib>
16+
#include <metal_tensor>
17+
18+
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
19+
20+
using namespace metal;
21+
using namespace mpp::tensor_ops;
22+
#endif
23+
1224
using namespace metal;
1325

1426
#define MAX(x, y) ((x) > (y) ? (x) : (y))
@@ -8054,6 +8066,8 @@ kernel void kernel_mul_mm(
80548066
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
80558067
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
80568068

8069+
threadgroup float * sc = (threadgroup float *)(shmem);
8070+
80578071
constexpr int NR0 = 64;
80588072
constexpr int NR1 = 32;
80598073

@@ -8073,15 +8087,6 @@ kernel void kernel_mul_mm(
80738087
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
80748088
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
80758089

8076-
S0_8x8 ma[4];
8077-
S1_8x8 mb[2];
8078-
8079-
simdgroup_float8x8 mc[8];
8080-
8081-
for (short i = 0; i < 8; i++){
8082-
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
8083-
}
8084-
80858090
const short il0 = (tiitg % NL0);
80868091

80878092
short il = il0;
@@ -8102,7 +8107,28 @@ kernel void kernel_mul_mm(
81028107
+ args.nb11*(r1 + lr1)
81038108
+ args.nb10*iy);
81048109

8110+
#ifndef GGML_METAL_USE_METAL4
8111+
S0_8x8 ma[4];
8112+
S1_8x8 mb[2];
8113+
8114+
simdgroup_float8x8 mc[8];
8115+
8116+
for (short i = 0; i < 8; i++){
8117+
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
8118+
}
8119+
#else
8120+
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
8121+
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
8122+
8123+
constexpr auto desc = matmul2d_descriptor(NR1, NR0, NK, false, true, false, matmul2d_descriptor::mode::multiply_accumulate);
8124+
8125+
matmul2d<desc, execution_simdgroups<4>> mm;
8126+
8127+
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
8128+
#endif
8129+
81058130
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
8131+
#ifndef GGML_METAL_USE_METAL4
81068132
// load data and store to threadgroup memory
81078133
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
81088134
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -8206,26 +8232,100 @@ kernel void kernel_mul_mm(
82068232
lsma += 8*64;
82078233
lsmb += 4*64;
82088234
}
8235+
#else
8236+
// load data and store to threadgroup memory
8237+
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8238+
threadgroup_barrier(mem_flags::mem_threadgroup);
8239+
8240+
// no need for dequantization
8241+
for (short i = 0; i < 16; i++) {
8242+
const short sx = 2*il0 + i/8;
8243+
const short sy = (tiitg/NL0)/8;
8244+
8245+
const short lx = i%8;
8246+
const short ly = (tiitg/NL0)%8;
8247+
//const short lx = (tiitg/NL0)%8;
8248+
//const short ly = i%8;
8249+
8250+
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
8251+
}
8252+
} else {
8253+
S0_4x4 temp_a;
8254+
dequantize_func(x, il, temp_a);
8255+
8256+
threadgroup_barrier(mem_flags::mem_threadgroup);
8257+
8258+
FOR_UNROLL (short i = 0; i < 16; i++) {
8259+
const short sx = 2*il0 + i/8;
8260+
const short sy = (tiitg/NL0)/8;
8261+
8262+
const short lx = i%8;
8263+
const short ly = (tiitg/NL0)%8;
8264+
//const short lx = (tiitg/NL0)%8;
8265+
//const short ly = i%8;
8266+
8267+
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
8268+
}
8269+
}
8270+
8271+
for (short i = 0; i < 8; ++i) {
8272+
const short sx = (tiitg%NL1);
8273+
const short sy = (tiitg/NL1)/8;
8274+
8275+
const short lx = i;
8276+
const short ly = (tiitg/NL1)%8;
8277+
//const short lx = (tiitg/NL1)%8;
8278+
//const short ly = i;
8279+
8280+
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
8281+
}
8282+
8283+
il = (il + 2 < nl) ? il + 2 : il % 2;
8284+
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
8285+
8286+
y += NK;
8287+
8288+
threadgroup_barrier(mem_flags::mem_threadgroup);
8289+
8290+
auto sA = tA.slice(0, 0);
8291+
auto sB = tB.slice(0, 0);
8292+
8293+
mm.run(sB, sA, cT);
8294+
#endif
82098295
}
82108296

82118297
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
82128298
// if no bounds checks on the output are needed, we can directly write to device memory
8299+
#ifdef GGML_METAL_USE_METAL4
8300+
device float * C = (device float *) dst +
8301+
r0 + \
8302+
r1 * args.ne0 + im*args.ne1*args.ne0;
8303+
8304+
auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
8305+
cT.store(tC);
8306+
#else
82138307
device float * C = (device float *) dst +
82148308
(r0 + 32*(sgitg & 1)) + \
82158309
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
82168310

82178311
for (short i = 0; i < 8; i++) {
8218-
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0, 0, false);
8312+
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
82198313
}
8314+
#endif
82208315
} else {
82218316
// block is smaller than 64x32, we should avoid writing data outside of the matrix
82228317
threadgroup_barrier(mem_flags::mem_threadgroup);
82238318

82248319
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
82258320

8321+
#ifdef GGML_METAL_USE_METAL4
8322+
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
8323+
cT.store(tC);
8324+
#else
82268325
for (short i = 0; i < 8; i++) {
82278326
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
82288327
}
8328+
#endif
82298329

82308330
threadgroup_barrier(mem_flags::mem_threadgroup);
82318331

0 commit comments

Comments
 (0)