Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -1078,12 +1078,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
op->src[0]->ne[0] != 256) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek sizes
// TODO: disabled for now, until optmized
op->src[0]->ne[0] != 256 &&
op->src[0]->ne[0] != 576) {
return false;
}
if (op->src[1]->type != op->src[2]->type) {
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2520,7 +2520,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {

// simdgroups per threadgroup (a.k.a. warps)
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
int32_t nsg = 4;
int32_t nsg = ne00 >= 512 ? 8 : 4;

const size_t smem = FATTN_SMEM(nsg);

Expand Down
13 changes: 8 additions & 5 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl(

constexpr short NC = (C/8)/NSG;

// note: do not unroll for large heads
#pragma unroll (DK <= 64 ? NC : 1)
for (short cc = 0; cc < NC; ++cc) {
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);

if (DK % 16 != 0) {
Expand All @@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl(
k8x8_t mk[2];
q8x8_t mq[2];

FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
// note: too much unroll can tank the performance for large heads
#pragma unroll (MIN(DK8/2, 4*NSG))
for (short i = 0; i < DK8/2; ++i) {
simdgroup_barrier(mem_flags::mem_none);

simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
Expand Down Expand Up @@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl(
pv += 8*NS20;
}
} else {
FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
constexpr short NC = (C/8)/2;

FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
s8x8_t vs[2];

simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
Expand Down Expand Up @@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS
Expand Down
Loading