Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nqpsg,
int32_t nsg,
int32_t nwg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
Expand All @@ -1450,11 +1451,17 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];

snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
char nq_suffix[8] = {0};
if (nqpsg > 1) {
snprintf(nq_suffix, sizeof(nq_suffix), "_q%d", nqpsg);
}

snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d%s",
"flash_attn_ext_vec",
ggml_type_name(op->src[1]->type),
dk,
dv);
dv,
nq_suffix);

snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
base,
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_att
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nqpsg,
int32_t nsg,
int32_t nwg);

Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32

// minimum ne11 (KV length) for the dk=128 Q=2 vec specialization;
// below this the K/V reuse savings do not offset the extra register pressure
#define OP_FLASH_ATTN_EXT_VEC_Q2_DK128_MIN_KV 4096
Comment on lines +112 to +114
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this threshold necessary? What is the impact of not having it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The optimization is more beneficial for longer contexts. In my testing, below ~2048 tokens there can be some performance regression or noise depending on the workload, so I set the threshold to 4096 to stay on the safe side.


#define OP_UNARY_NUM_SCALE 10
#define OP_UNARY_NUM_FILL 11
#define OP_UNARY_NUM_CLAMP 12
Expand Down
14 changes: 11 additions & 3 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <cassert>
#include <algorithm>
#include <cstdlib>
#include <limits>
#include <cmath>

Expand Down Expand Up @@ -2871,7 +2872,14 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
#undef FATTN_SMEM
} else {
// half4x4 kernel
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
// Amortize one K/V tile read across two query rows.
static const bool disable_q2 = std::getenv("GGML_METAL_FA_DISABLE_Q2") != nullptr;
const bool can_q2 = !disable_q2 &&
op->src[1]->type == GGML_TYPE_F16 && ne01 >= 2 &&
( (ne00 == 256 && ne20 == 256) ||
(ne00 == 128 && ne20 == 128 && ne11 >= OP_FLASH_ATTN_EXT_VEC_Q2_DK128_MIN_KV) );

const int nqptg = can_q2 ? 2 : OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
const int nhptg = 1; // heads per threadgroup

Expand Down Expand Up @@ -2935,7 +2943,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg)*nqptg)*(sizeof(float)/2), 16))

int64_t nsg = 1;

Expand Down Expand Up @@ -2990,7 +2998,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};

auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nqptg, nsg, nwg);

GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

Expand Down
Loading
Loading