8 changes: 8 additions & 0 deletions ggml/src/ggml-cpu/common.h
@@ -6,6 +6,9 @@
#include "ggml-impl.h"
#include "simd-mappings.h"

#define GGML_FA_TILE_Q 32
#define GGML_FA_TILE_KV 16

#ifdef __cplusplus

#include <utility>
@@ -84,4 +87,9 @@ static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_pa
return {ir0, ir1};
}

struct ggml_fa_tile_config {
static constexpr size_t Q = GGML_FA_TILE_Q;
static constexpr size_t KV = GGML_FA_TILE_KV;
};

#endif
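For reference, the tile shape fixed by these constants is shared by the planner in ggml-cpu.c and the kernel in ops.cpp: each tile covers 32 query rows by 16 KV positions. A minimal compile-time sketch of how the config is consumed; `score_tile_elems` is an illustrative helper, not part of the patch:

```cpp
// Sketch only: mirrors ggml_fa_tile_config from common.h and checks the
// score-tile size at compile time. score_tile_elems is hypothetical.
#include <cstddef>

#define GGML_FA_TILE_Q 32
#define GGML_FA_TILE_KV 16

struct ggml_fa_tile_config {
    static constexpr size_t Q  = GGML_FA_TILE_Q;  // query rows per tile
    static constexpr size_t KV = GGML_FA_TILE_KV; // KV positions per tile
};

// One tile of attention scores is Q x KV floats: 32*16 = 512 floats (2 KiB).
constexpr size_t score_tile_elems = ggml_fa_tile_config::Q * ggml_fa_tile_config::KV;
static_assert(score_tile_elems == 512, "32x16 score tile");
```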
9 changes: 6 additions & 3 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -14,6 +14,7 @@
#include "vec.h"
#include "ops.h"
#include "ggml.h"
#include "common.h"

#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
@@ -2866,10 +2867,12 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne10 = node->src[1]->ne[0]; // DK
const int64_t ne20 = node->src[2]->ne[0]; // DV
const int64_t DK = node->src[1]->ne[0];
const int64_t DV = node->src[2]->ne[0];

cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
// Tiled flash attention scratch (tile sizes defined in common.h)
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
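The new workspace formula reserves, per thread, one converted Q tile, two Q×KV tiles (scores and mask), the FP32 output accumulator, and an F32 staging buffer for one V tile. A worked sketch with assumed head sizes DK = DV = 128; the names mirror the kernel's scratch layout in ops.cpp below:

```cpp
// Sketch: per-thread scratch for the tiled path, assuming DK = DV = 128.
// Layout mirrors Q_q | KQ | mask32 | VKQ32 | V32 used by the kernel.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t Q = 32, KV = 16;     // GGML_FA_TILE_Q, GGML_FA_TILE_KV
    const size_t DK = 128, DV = 128;  // example head sizes (assumption)

    const size_t q_q    = Q*DK;      // converted Q tile (held in the KV type)
    const size_t kq     = Q*KV;      // attention scores
    const size_t mask32 = Q*KV;      // mask tile in F32
    const size_t vkq32  = Q*DV;      // FP32 output accumulator
    const size_t v32    = KV*DV;     // F32 staging for an F16 V tile

    const size_t n = q_q + kq + mask32 + vkq32 + v32;   // 11264 floats
    printf("%zu bytes per thread\n", n*sizeof(float));  // 45056 (~44 KiB)
    return 0;
}
```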
290 changes: 289 additions & 1 deletion ggml/src/ggml-cpu/ops.cpp
@@ -8164,6 +8164,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf

for (int64_t ic = 0; ic < nek1; ++ic) {
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
@@ -8271,6 +8272,280 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
}
}

static void ggml_compute_forward_flash_attn_ext_tiled(
const ggml_compute_params * params,
ggml_tensor * dst,
int ir0, int ir1) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];

GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

const int64_t DK = nek0;
const int64_t DV = nev0;
const int64_t N = neq1;

GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N);

// input tensor rows must be contiguous
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));

GGML_ASSERT(neq0 == DK);
GGML_ASSERT(nek0 == DK);
GGML_ASSERT(nev0 == DV);

GGML_ASSERT(neq1 == N);

// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);

GGML_ASSERT(k->type == v->type);
const ggml_type kv_type = k->type;

const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
const size_t kv_type_size = ggml_type_size(kv_type);

// broadcast factors
const int64_t rk2 = neq2/nek2;
const int64_t rk3 = neq3/nek3;

const int64_t rv2 = neq2/nev2;
const int64_t rv3 = neq3/nev3;

float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;

memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));

if (logit_softcap != 0) {
scale /= logit_softcap;
}

const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

int ith = params->ith;

static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;

GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");

int ir = ir0;
while (ir < ir1) {
// q indices for the start of this tile
const int iq3 = ir/(neq2*neq1);
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);

// Number of valid rows in this tile:
// - limited by tile size (Q_TILE_SZ)
// - limited by chunk boundary (ir1 - ir)
// - limited by head boundary (neq1 - iq1) to avoid crossing into next head
const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
GGML_ASSERT(tile_rows > 0);

const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;

float S[Q_TILE_SZ];
float M[Q_TILE_SZ];

for (int i = 0 ; i < Q_TILE_SZ; ++i) {
S[i] = 0.;
M[i] = -INFINITY;
}

// Per-thread scratch layout:
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for F16 conversion)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);

void * Q_q = base;
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile

memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));

// k indices
const int ik3 = iq3 / rk3;
const int ik2 = iq2 / rk2;

// v indices
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;

for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
}
// Zero-pad remaining rows
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
}

for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {

// skip the tile entirely if all the mask values are -inf
if (mask) {
bool can_skip = true;
for (int tq = 0; tq < tile_rows; tq++) {
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
can_skip = false;
}
}
}

if (can_skip) {
continue;
}
}

for (int tq = 0; tq < Q_TILE_SZ; tq++) {
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
float s;
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
KQ[tq * KV_TILE_SZ + tk] = s * scale;
}
}

if (logit_softcap != 0.0f) {
ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
}

if (mask) {
ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
}

bool skip[Q_TILE_SZ] = {};

for (int tq = 0; tq < Q_TILE_SZ; tq++) {
float * kq_row = KQ + tq * KV_TILE_SZ;

float tile_max;
ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);

if (tile_max == -INFINITY) {
skip[tq] = true;
continue;
}

const float Mold = M[tq];
const float Mnew = fmaxf(Mold, tile_max);

if (Mnew > Mold) {
const float ms = expf(Mold - Mnew);
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
S[tq] *= ms;
}
M[tq] = Mnew;

S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
}

// Convert V tile to F32 first (if F16), then do MAD
// On x86, ggml_vec_mad_f16 internally converts F16<->F32 on every load/store, so pre-converting is faster.
// TODO: on ARM, native f16 should be faster
if (kv_type == GGML_TYPE_F16) {
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
}
}
} else {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
}
}
}
}

// sinks (apply only to valid rows in the tile)
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];

for (int tq = 0; tq < tile_rows; tq++) {
float ms = 1.0f;
float vs = 1.0f;

if (s > M[tq]) {
ms = expf(M[tq] - s);
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
} else {
vs = expf(s - M[tq]);
}

S[tq] = S[tq] * ms + vs;
}
}

for (int tq = 0; tq < tile_rows; tq++) {
// V /= S
const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);

// dst indices
const int i1 = iq1 + tq;
const int i2 = iq2;
const int i3 = iq3;

// permute(0, 2, 1, 3)
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
}

ir += tile_rows;
}
}
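The M/S bookkeeping above is the online softmax from the paper referenced in the one-chunk kernel (arXiv:2112.05682): each KV tile may raise a row's running maximum, in which case the partial accumulator and denominator are rescaled before the new tile's probabilities are folded in. A standalone sketch of one row's update, with illustrative names:

```cpp
// Sketch of the per-row online-softmax step performed for each KV tile.
// scores[] holds the masked, scaled QK^T row for this tile; acc[] is the
// row's FP32 V accumulator. All names here are illustrative.
#include <cmath>

void online_softmax_step(float * scores, int kv_tile, float * acc, int dv,
                         float & M, float & S) {
    float tile_max = -INFINITY;
    for (int i = 0; i < kv_tile; ++i) tile_max = fmaxf(tile_max, scores[i]);
    if (tile_max == -INFINITY) return;      // fully masked: nothing to fold in

    const float Mnew = fmaxf(M, tile_max);
    if (Mnew > M) {                         // rescale previous partial results
        const float ms = expf(M - Mnew);
        for (int j = 0; j < dv; ++j) acc[j] *= ms;
        S *= ms;
    }
    M = Mnew;
    for (int i = 0; i < kv_tile; ++i) {     // exponentiate in place, as
        scores[i] = expf(scores[i] - Mnew); // ggml_vec_soft_max_f32 does
        S += scores[i];
    }
    // the caller then MADs scores[i] * V_row[i] into acc (the V*KQ step)
}
```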

static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -8343,14 +8618,27 @@ static void ggml_compute_forward_flash_attn_ext_f16(
// The number of elements in each chunk
const int64_t dr = (nr + nchunk - 1) / nchunk;

static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size

// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;

while (current_chunk < nchunk) {
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);

ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
if (use_tiled) {
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
} else {
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
}

current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
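Dispatch falls back to the existing one-chunk path whenever a precondition fails. A small sketch of the predicate with made-up shapes; the helper name and values are illustrative:

```cpp
// Sketch of the use_tiled predicate above; shapes are invented examples.
#include <cstdint>
#include <cstdio>

bool would_use_tiled(bool q_is_f32, bool kv_is_f32_or_f16, bool k_type_eq_v_type,
                     int64_t nek1 /* KV length */, int64_t neq1 /* Q rows */) {
    const int64_t Q_TILE_SZ = 32, KV_TILE_SZ = 16;
    return q_is_f32 && kv_is_f32_or_f16 && k_type_eq_v_type &&
           nek1 % KV_TILE_SZ == 0 && neq1 >= Q_TILE_SZ;
}

int main() {
    // Prompt processing: 512 Q rows against 4096 KV positions -> tiled.
    printf("%d\n", would_use_tiled(true, true, true, 4096, 512)); // 1
    // Single-token decode: 1 Q row -> one-chunk fallback.
    printf("%d\n", would_use_tiled(true, true, true, 4096, 1));   // 0
    return 0;
}
```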