Skip to content

Commit 57fa815

Browse files
committed
cont : better ifdefs
1 parent f8416cf commit 57fa815

File tree

1 file changed

+29
-34
lines changed

1 file changed

+29
-34
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8285,40 +8285,6 @@ kernel void kernel_mul_mm(
82858285

82868286
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
82878287
}
8288-
8289-
il = (il + 2 < nl) ? il + 2 : il % 2;
8290-
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
8291-
8292-
y += NK;
8293-
8294-
// load matrices from threadgroup memory and conduct outer products
8295-
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
8296-
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
8297-
8298-
threadgroup_barrier(mem_flags::mem_threadgroup);
8299-
8300-
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
8301-
simdgroup_barrier(mem_flags::mem_none);
8302-
8303-
FOR_UNROLL (short i = 0; i < 4; i++) {
8304-
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
8305-
}
8306-
8307-
simdgroup_barrier(mem_flags::mem_none);
8308-
8309-
FOR_UNROLL (short i = 0; i < 2; i++) {
8310-
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
8311-
}
8312-
8313-
simdgroup_barrier(mem_flags::mem_none);
8314-
8315-
FOR_UNROLL (short i = 0; i < 8; i++){
8316-
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
8317-
}
8318-
8319-
lsma += 8*64;
8320-
lsmb += 4*64;
8321-
}
83228288
#else
83238289
// load data and store to threadgroup memory
83248290
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
@@ -8378,6 +8344,7 @@ kernel void kernel_mul_mm(
83788344

83798345
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
83808346
}
8347+
#endif
83818348

83828349
il = (il + 2 < nl) ? il + 2 : il % 2;
83838350
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -8386,6 +8353,34 @@ kernel void kernel_mul_mm(
83868353

83878354
threadgroup_barrier(mem_flags::mem_threadgroup);
83888355

8356+
#ifndef GGML_METAL_HAS_TENSOR
8357+
// load matrices from threadgroup memory and conduct outer products
8358+
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
8359+
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
8360+
8361+
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
8362+
simdgroup_barrier(mem_flags::mem_none);
8363+
8364+
FOR_UNROLL (short i = 0; i < 4; i++) {
8365+
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
8366+
}
8367+
8368+
simdgroup_barrier(mem_flags::mem_none);
8369+
8370+
FOR_UNROLL (short i = 0; i < 2; i++) {
8371+
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
8372+
}
8373+
8374+
simdgroup_barrier(mem_flags::mem_none);
8375+
8376+
FOR_UNROLL (short i = 0; i < 8; i++){
8377+
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
8378+
}
8379+
8380+
lsma += 8*64;
8381+
lsmb += 4*64;
8382+
}
8383+
#else
83898384
auto sA = tA.slice(0, 0);
83908385
auto sB = tB.slice(0, 0);
83918386

0 commit comments

Comments
 (0)