Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
const int nbatch_fa) {
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
const int ne11, const int ne12, const int nbatch_fa) {
constexpr int ncols = ncols1*ncols2;

const int bidx0 = blockIdx.x;
Expand All @@ -641,12 +641,14 @@ static __global__ void flash_attn_stream_k_fixup(

const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);

const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.

const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;

const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;

const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
Expand All @@ -655,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup(
return;
}

const int sequence = kbc0 / (iter_k*iter_j*iter_z);
const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;

const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.

if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) {
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
return;
}

dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid;
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;

// Load the partial result that needs a fixup:
float dst_val = 0.0f;
Expand All @@ -682,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup(
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
Expand Down Expand Up @@ -883,9 +889,10 @@ void launch_fattn(
}
}

const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2);
const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3];
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int gqa_ratio = Q->ne[2] / K->ne[2];
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];

// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
Expand Down Expand Up @@ -960,7 +967,7 @@ void launch_fattn(

blocks_num.x = ntiles_x;
blocks_num.y = parallel_blocks;
blocks_num.z = ntiles_z*Q->ne[3];
blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];

if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
Expand Down Expand Up @@ -1014,7 +1021,7 @@ void launch_fattn(

flash_attn_stream_k_fixup<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
}
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(DV, 1, 1);
Expand Down
69 changes: 37 additions & 32 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -933,14 +933,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float logit_softcap,
const uint3 ne01,
const int ne02,
const int gqa_ratio,
const int ne11,
const int stride_Q1,
const int stride_Q2,
const int stride_K,
const int stride_V,
const int stride_mask,
const int jt,
const int zt,
const int zt_gqa,
const int kb0_start,
const int kb0_stop) {
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
Expand Down Expand Up @@ -1023,7 +1024,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int j = jc / ncols2;
const int c = jc % ncols2;

if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
Expand Down Expand Up @@ -1409,7 +1410,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int j_dst = jc_dst / ncols2;
const int c_dst = jc_dst % ncols2;

if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
continue;
}

Expand Down Expand Up @@ -1448,7 +1449,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
#else
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
scale, slope, logit_softcap, ne01, ne02,
scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
jt, kb0_start, kb0_stop);
NO_DEVICE_CODE;
Expand Down Expand Up @@ -1521,13 +1522,13 @@ static __global__ void flash_attn_ext_f16(

const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);

const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;

// kbc == k block continuous, current index in continuous ijk space.
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;

// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
Expand All @@ -1538,22 +1539,24 @@ static __global__ void flash_attn_ext_f16(
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);

while (kbc < kbc_stop && kb0_stop == iter_k) {
const int sequence = kbc / (iter_k*iter_j*iter_z);
const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;

const int head0 = zt * ncols2;
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.

const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);

const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;

const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;

if (KV_max) {
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
Expand All @@ -1563,12 +1566,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
} else {
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
}

kbc += iter_k;
Expand All @@ -1582,22 +1585,24 @@ static __global__ void flash_attn_ext_f16(
return;
}

const int sequence = kbc / (iter_k*iter_j*iter_z);
const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;

const int head0 = zt * ncols2;
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.

const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);

const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;

const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;

if (KV_max) {
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
Expand All @@ -1607,7 +1612,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
Expand Down
4 changes: 2 additions & 2 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8216,8 +8216,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nh : { 4, }) {
for (int nr3 : { 1, 3, }) {
if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
for (int nr2 : { 1, 4, 16 }) {
if (nr2 == 16 && hsk != 128) continue;
for (int nr2 : { 1, 4, 12 }) {
if (nr2 == 12 && hsk != 128) continue;
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
for (int kv : { 113, 512, 1024, }) {
if (nr2 != 1 && kv != 512) continue;
Expand Down
Loading