Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@
#define FC_COUNT_EQUAL 1000

// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
#define OP_FLASH_ATTN_EXT_NQPSG 8
#define OP_FLASH_ATTN_EXT_NCPSG 64

#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32

// kernel argument structs
Expand Down
38 changes: 11 additions & 27 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2295,7 +2295,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
// return res;
//}

const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;

const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
Expand Down Expand Up @@ -2411,7 +2411,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {

if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
// half8x8 kernel
const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup

GGML_ASSERT(nqptg <= 32);
Expand Down Expand Up @@ -2578,9 +2578,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
#undef FATTN_SMEM
} else {
// half4x4 kernel
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
const int nkpsg = 1*ncpsg;
const int nhptg = 1; // heads per threadgroup

GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
Expand Down Expand Up @@ -2632,35 +2632,19 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);
}

// note: for simplicity assume the K is larger or equal than V
GGML_ASSERT(ne10 >= ne20);

// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
//
// ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
// avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
if (smem > props_dev->max_theadgroup_memory_size/2) {
break;
}
nsgmax *= 2;
}
nsgmax /= 2;

// simdgroups per threadgroup (a.k.a. warps)
//const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))

int64_t nsg = 1;
while (nsg <= nsgt) {
nsg *= 2;
}
nsg /= 2;

// workgroups
// each workgroup handles nsg*nkpsg cache values
Expand All @@ -2673,7 +2657,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
} else {
nwg = 32;
nsg = 1;
while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
nsg *= 2;
}
}
Expand Down Expand Up @@ -2739,7 +2723,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {

ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
} else {
// sanity checks
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
Expand All @@ -2752,7 +2736,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);

ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);

// sync the 2 kernels
ggml_metal_op_concurrency_reset(ctx);
Expand Down
86 changes: 21 additions & 65 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -5931,7 +5931,7 @@ template<
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
short DK, // K head size
short DV, // V head size
short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext(
constant ggml_metal_kargs_flash_attn_ext & args,
Expand Down Expand Up @@ -6141,11 +6141,10 @@ template<
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
short DK, // K head size
short DV, // V head size
short NE, // head elements per thread
short Q, // queries per threadgroup
short C, // cache items per threadgroup
short NSG> // number of simd groups
void kernel_flash_attn_ext_vec_impl(
short NE = 4, // head elements per thread
short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext_vec & args,
device const char * q,
device const char * k,
Expand All @@ -6162,6 +6161,7 @@ void kernel_flash_attn_ext_vec_impl(
static_assert(DV % 32 == 0, "DV must be divisible by 32");

#define NWG (FC_flash_attn_ext_vec_nwg)
#define NSG (FC_flash_attn_ext_vec_nsg)

#define NS10 (FC_flash_attn_ext_vec_ns10)
#define NS20 (FC_flash_attn_ext_vec_ns20)
Expand Down Expand Up @@ -6190,12 +6190,12 @@ void kernel_flash_attn_ext_vec_impl(

const short T = PK + NSG*SH; // shared memory size per query in (half)

//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results

// store the result for all queries in shared memory (the O matrix from the paper)
so4 += tiisg;
Expand All @@ -6213,11 +6213,13 @@ void kernel_flash_attn_ext_vec_impl(
// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q);

for (short i = tiisg; i < PK4; i += NW) {
if (iq1 < args.ne01 && i < DK4) {
sq4[i] = (q4_t) q4[i];
} else {
sq4[i] = (q4_t) 0.0f;
if (iq1 < args.ne01) {
for (short i = tiisg; i < PK4; i += NW) {
if (i < DK4) {
sq4[i] = (q4_t) q4[i];
} else {
sq4[i] = (q4_t) 0.0f;
}
}
}

Expand Down Expand Up @@ -6295,7 +6297,7 @@ void kernel_flash_attn_ext_vec_impl(
}

// skip -INF blocks
if (simd_max(sm[tiisg]) == -INFINITY) {
if (simd_max(sm[tiisg]) <= -MAXHALF) {
continue;
}

Expand Down Expand Up @@ -6569,57 +6571,11 @@ void kernel_flash_attn_ext_vec_impl(
}

#undef NWG
#undef NSG
#undef NS10
#undef NS20
}

template<
typename q4_t, // query types in shared memory
typename k4_t, // key types in shared memory
typename v4_t, // value types in shared memory
typename qk_t, // Q*K types
typename s_t, // soft-max types
typename s4_t,
typename o4_t, // attention accumulation types
typename kd4_t, // key type in device memory
short nl_k,
void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
typename vd4_t, // value type in device memory
short nl_v,
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
short DK, // K head size
short DV, // V head size
short NE = 4, // head elements per thread
short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext_vec & args,
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
switch (FC_flash_attn_ext_vec_nsg) {
// note: disabled cases to reduce library load time
case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
case 2: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 4>(FWD_ARGS); break;
//case 8: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 8>(FWD_ARGS); break;
//case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
//case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS
}

// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
//
Expand Down