97 changes: 76 additions & 21 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2858,11 +2858,10 @@ struct vk_fa_tuning_params {
}
};

static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);

static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {

vk_fa_tuning_params result{};
result.path = FA_SCALAR;
@@ -2914,7 +2913,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,

result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;

if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
result.block_rows /= 2;
}

@@ -3445,21 +3444,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->fp16) {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
}
} else {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
}
}
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
@@ -8752,7 +8777,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}

static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
GGML_UNUSED(f32acc);
// Needs to be kept up to date on shader changes
const uint32_t wg_size = params.workgroup_size;
@@ -8761,21 +8786,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con

const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);

const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);

// tmpsh is overestimated slightly
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;

const uint32_t masksh = Bc * (Br + 1) * float_type_size;

uint32_t Qf, kvsh, kblocksh_size;
if (mmq) {
// block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
Qf = Br * (hsk / 32) * block_b_size;

// kvsh uses D = HSV (K goes through kblocksh instead)
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;

// block_a_cache size depends on quant type
uint32_t block_a_size;
switch (kv_type) {
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
default: block_a_size = 0; break;
}
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
} else {
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;

const uint32_t D = std::max(hsk, hsv);
kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;

kblocksh_size = 0;
}

const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;

VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);

return supported;
}
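For illustration, the budget this function enforces can be recomputed standalone. The following minimal C++ sketch is not PR code — the tile sizes Br = Bc = 32, the 128-thread workgroup, HSK = HSV = 128, Q8_0 K/V, fp16 device floats, and enabled staging are all assumed values — but it mirrors the arithmetic above:

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t wg_size = 128, Br = 32, Bc = 32; // assumed tile/workgroup sizes
    const uint32_t hsk = 128, hsv = 128;            // assumed head sizes
    const uint32_t float_type_size = 2;             // fp16 on device

    const uint32_t tmpsh   = wg_size * sizeof(float); // slightly overestimated, as above
    const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
    const uint32_t masksh  = Bc * (Br + 1) * float_type_size;

    // MMQ path: Q cached as blocks of 32 int8 values plus a 2-component scale
    const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
    const uint32_t Qf = Br * (hsk / 32) * block_b_size;

    // kvsh stages V only (K goes through kblocksh), so D = HSV
    const uint32_t kvsh = Bc * (hsv / 4 + 1) * 4 * float_type_size;

    // Q8_0 K blocks: int32_t qs[8] + one fp16 scale
    const uint32_t block_a_size = 8 * sizeof(int32_t) + float_type_size;
    const uint32_t kblocksh_size = Bc * (hsk / 32) * block_a_size;

    const uint32_t total = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
    printf("total shmem: %u bytes\n", total); // 21056 here, vs. maxComputeSharedMemorySize
    return 0;
}

At these assumed sizes the budget lands comfortably under the common 32 KB / 48 KB shared-memory limits, which is consistent with the fallback above only halving block_rows for larger head sizes.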
184 changes: 182 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -10,6 +10,13 @@
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
#endif

#ifdef MMQ
#extension GL_EXT_integer_dot_product : require
#extension GL_KHR_shader_subgroup_clustered : require

#include "mul_mmq_shmem_types.glsl"
#endif

#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable

@@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
const uint32_t masksh_stride = Br + 1;
shared FLOAT_TYPE masksh[Bc * masksh_stride];

#ifndef MMQ
const uint32_t qf_stride = HSK / 4 + 1;
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
#else

const uint32_t qf_stride = HSK / 32;
shared block_b_cache Qf[Br * qf_stride];
#endif

#ifndef MMQ
const uint32_t D = HSK > HSV ? HSK : HSV;
#else
const uint32_t D = HSV;
#endif
const uint32_t kvsh_stride = D / 4 + 1;
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];

#ifdef MMQ

shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1];
#endif

shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];

#ifdef MMQ
#include "flash_attn_mmq_funcs.glsl"
#endif

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@@ -82,10 +108,39 @@ void main() {
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N;
#ifndef MMQ
if (is_in_bounds) {
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
#else
const uint buf_ib = r * qf_stride + d / 8;
const uint buf_iqs = d % 8;

FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f);
const FLOAT_TYPEV4 abs_vals = abs(vals);

const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
Contributor:

It's interesting to fuse this quantization. Seems like it's probably the right tradeoff since Q is just loaded once and is pretty small.

I'm curious, does ggml-cuda use quantization in FA? Just wondering if the quality effects of it are considered OK.

Contributor Author:

It does, but only in the vector kernel. I don't think the tile kernel is using DP4A.

(A standalone sketch of the fused quantization discussed here follows at the end of this hunk.)

const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8);
const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0);
const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0);
vals = round(vals * qd_inv);

Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));

#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
}
#else // Q4_0, Q4_1, Q5_0, Q5_1
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);

if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
}
#endif
#endif
}
barrier();
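To make the fused quantization above concrete: each group of 32 scaled Q values is quantized with d = amax/127, and for the *_1-style K formats the shader also stores d·Σq for the correction term. A minimal scalar C++ sketch of the same scheme — the struct layout and function name are illustrative assumptions, not PR code:

#include <algorithm>
#include <cmath>
#include <cstdint>

struct block_q8_1_ref {
    int8_t qs[32]; // rounded quants
    float  d;      // scale = amax / 127
    float  s;      // d * sum(qs), consumed by the Q4_1/Q5_1-style correction
};

// Scalar equivalent of the per-32-value quantization the shader performs
// across a subgroup cluster of 8 threads (4 values per thread).
static block_q8_1_ref quantize_q8_1_ref(const float * x) {
    float amax = 0.0f;
    for (int i = 0; i < 32; ++i) {
        amax = std::max(amax, std::fabs(x[i]));
    }
    const float d     = amax / 127.0f;
    const float d_inv = (d != 0.0f) ? 1.0f / d : 0.0f;

    block_q8_1_ref b{};
    float sum = 0.0f;
    for (int i = 0; i < 32; ++i) {
        const float q = std::round(x[i] * d_inv);
        b.qs[i] = (int8_t) q;
        sum    += q;
    }
    b.d = d;
    b.s = sum * d; // matches FLOAT_TYPEV2(qd, sum * qd) above
    return b;
}

For Q8_0 and IQ4_NL K the correction term is not needed, which is why the shader stores FLOAT_TYPEV2(qd, 0.0) on that path.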

@@ -195,6 +250,7 @@ void main() {

if (SHMEM_STAGING != 0) {
barrier();
#ifndef MMQ
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
@@ -214,9 +270,29 @@
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
#else // MMQ
const uint ints_per_block = 8 / QUANT_R_MMQ;
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
const uint32_t iqs = (idx + tid) % ints_per_block;
const uint32_t ib = (idx + tid) / ints_per_block;
const uint32_t c = ib / (HSK / 32);
const uint32_t block = ib % (HSK / 32);
if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) {
const uint buf_ib = c * qf_stride + block;
if (!KV_bounds_check || j * Bc + c < KV) {
const uint global_ib = (j * Bc + c) * k_stride + block;
k_block_to_shmem(buf_ib, global_ib, iqs, k_offset);
} else {
k_block_to_shmem_zero(buf_ib, iqs);
}
}
}
#endif // MMQ
barrier();
}

#ifndef MMQ
// More d iterations means Q register caching becomes relevant
// Few iterations means the additional registers needed are worse than the speed-up from caching
if (HSK_per_thread / 4 > 4) {
@@ -275,6 +351,110 @@
}
}
}
#else // MMQ
const uint hsk4 = HSK_per_thread / 4;
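// Choose the largest power-of-two step (in 4-value units) that evenly divides the per-thread K width.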
const uint d_per_step = (hsk4 % 8 == 0) ? 8 :
(hsk4 % 4 == 0) ? 4 :
(hsk4 % 2 == 0) ? 2 : 1;

[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}

[[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) {
int32_t k_quants[d_per_step];
ACC_TYPEV2 k_dm;

if (SHMEM_STAGING != 0) {
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
#else
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
#endif

#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
uint vui = kblocksh[buf_ib].qs[d];
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
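// Multiplying the 4-bit qh nibble by 0x02040810 moves bit i of the nibble
// to bit 4 of byte i; the 0x10101010 mask then keeps exactly those high bits.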
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
}
} else
#endif
{
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
}
}
} else {
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
const uint ib = coord / BLOCK_SIZE;
const uint iqs = (coord % BLOCK_SIZE);

#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
#else
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
#if defined(DATA_A_Q5_0)
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
k_packed.k_data_packed16[k_offset + ib].qh[1]));
#elif defined(DATA_A_Q5_1)
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
#endif
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
#if defined(A_TYPE_PACKED32)
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
#else
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
#endif
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (qh >> (d * 4)) & 0xF;
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
}
} else
#endif
{
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
}
}
}

[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8;

int32_t acc = 0;
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]);
}

Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x;
if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) {
Sf[r][c] += k_dot_correction(qib, k_dm);
}
}
}
}
#endif // MMQ
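For reference, the accumulation above follows the standard MMQ identity: with a K block (d_k, m_k) and a Q block (d_q, s_q = d_q·Σq), Σ k_i·q_i = d_k·d_q·Σ(qk_i·qq_i) + m_k·s_q, where the m_k term exists only for the *_1 formats — this is what k_dot_correction contributes once per 32-wide block. A minimal scalar C++ sketch, with an illustrative name and any quant offsets assumed to be folded into the K quants:

#include <cstdint>

// Scalar version of one 32-wide Q.K block dot product (illustrative).
// qk: K quants, qq: Q quants, d_k/m_k: K scale and min (m_k == 0 for *_0
// formats, assuming any offset is folded into qk), d_q: Q scale,
// s_q: precomputed d_q * sum(qq).
static float block_dot_ref(const int8_t * qk, const int8_t * qq,
                           float d_k, float m_k, float d_q, float s_q) {
    int32_t acc = 0;
    for (int i = 0; i < 32; ++i) {
        acc += (int32_t) qk[i] * (int32_t) qq[i]; // dotPacked4x8EXT covers 4 per call
    }
    return (float) acc * d_k * d_q + m_k * s_q;
}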

[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split
5 changes: 5 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
@@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif

#if defined(A_TYPE_PACKED32)
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
#endif

#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif