Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e85961a
vulkan: use coopmat for flash attention p*v matrix multiplication
0cc4m Jan 20, 2026
7c46dbf
fix P loading issue
0cc4m Jan 21, 2026
b58acd0
fix barrier position
0cc4m Jan 21, 2026
379608f
remove reduction that is no longer needed
0cc4m Jan 21, 2026
4cfc48d
move max thread reduction into loop
0cc4m Jan 21, 2026
b4e96f5
remove osh padding
0cc4m Jan 22, 2026
1ed1f35
add bounds checks and padding
0cc4m Jan 22, 2026
0bbee6d
remove unused code
0cc4m Jan 23, 2026
f435f34
fix shmem sizes, loop duration and accesses
0cc4m Jan 23, 2026
3503133
don't overwrite Qf, add new shared psh buffer instead
0cc4m Jan 23, 2026
1226ba0
add missing bounds checks
0cc4m Jan 23, 2026
44cd07a
use subgroup reductions
0cc4m Jan 24, 2026
3037c75
optimize
0cc4m Jan 24, 2026
4b3b6a6
move bounds check, reduce barriers
0cc4m Jan 24, 2026
1c40f9e
support other Bc values and other subgroup sizes
0cc4m Jan 24, 2026
61745fd
remove D_split
0cc4m Jan 24, 2026
74d3246
replace Of register array with shared memory Ofsh array
0cc4m Jan 24, 2026
e0c414d
parallelize HSV across the rowgroups
0cc4m Jan 24, 2026
ed11a95
go back to Of in registers, not shmem
0cc4m Jan 24, 2026
9384172
vectorize sfsh
0cc4m Jan 24, 2026
a875cc2
don't store entire K tile in shmem
0cc4m Jan 24, 2026
7d75bea
fixes
0cc4m Jan 24, 2026
a0c9f40
load large k tiles to shmem on Nvidia
0cc4m Jan 25, 2026
1e576ba
adapt shared memory host check function to shader changes
0cc4m Jan 25, 2026
ef86d3a
remove Bc 32 case
0cc4m Jan 25, 2026
fcd3a26
remove unused variable
0cc4m Jan 25, 2026
80a4ac0
fix missing mask reduction tmspsh barrier
0cc4m Jan 26, 2026
92051ec
fix mask bounds check
0cc4m Jan 27, 2026
32465b8
fix rowmax f16 under/overflow to inf
0cc4m Jan 28, 2026
7056c66
fix flash_attn_cm2 BLOCK_SIZE preprocessor directives
0cc4m Jan 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 43 additions & 21 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3162,17 +3162,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
// For scalar, use 128 (arbitrary)
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
const uint32_t D = (hsk|hsv);
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);

uint32_t wg_size;
switch (path) {
case FA_COOPMAT2:
wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128);
break;
case FA_COOPMAT1:
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
break;
default:
wg_size = scalar_flash_attention_workgroup_size;
break;
}

// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
const uint32_t D_lsb = D ^ (D & (D-1));
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);

return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
// Nvidia prefers shared memory use to load large tiles of K
// AMD prefers loading K directly from global memory
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In #19081 I saw that the coopmat1 path is disabled due to using too much shared memory. We should set k_load_shmem to false for large enough head sizes (I'm not sure what the exact threshold is).

Even after that, I still see some performance dropoffs I don't understand:

FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,2560,1,1),  v(512,2560,1,1),  m(2560,512,1,1): 47 x 1094.62 us = 51447.4 us (52111.5 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,3072,1,1),  v(512,3072,1,1),  m(3072,512,1,1): 47 x 1324.57 us = 62254.7 us (51678 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,3584,1,1),  v(512,3584,1,1),  m(3584,512,1,1): 47 x 1551.69 us = 72929.3 us (51466.3 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,4096,1,1),  v(512,4096,1,1),  m(4096,512,1,1): 47 x 1776.2 us = 83481.6 us (51383.7 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,4608,1,1),  v(512,4608,1,1),  m(4608,512,1,1): 47 x 1999.99 us = 93999.5 us (51338.6 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,5120,1,1),  v(512,5120,1,1),  m(5120,512,1,1): 47 x 4332.95 us = 203649 us (26329.6 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,5632,1,1),  v(512,5632,1,1),  m(5632,512,1,1): 47 x 4754.85 us = 223478 us (26392.7 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,6144,1,1),  v(512,6144,1,1),  m(6144,512,1,1): 47 x 5216.38 us = 245170 us (26244.6 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,6656,1,1),  v(512,6656,1,1),  m(6656,512,1,1): 47 x 5654.59 us = 265766 us (26228.4 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,7168,1,1),  v(512,7168,1,1),  m(7168,512,1,1): 47 x 6095.29 us = 286479 us (26203.7 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,7680,1,1),  v(512,7680,1,1),  m(7680,512,1,1): 47 x 6531.35 us = 306974 us (26200.9 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,8192,1,1),  v(512,8192,1,1),  m(8192,512,1,1): 47 x 6976.51 us = 327896 us (26164.4 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,8704,1,1),  v(512,8704,1,1),  m(8704,512,1,1): 47 x 26301.7 us = 1.23618e+06 us (7373.85 GFLOPS/s)
FLASH_ATTN_EXT dst(512,20,512,1),  q(576,512,20,1),  k(576,9216,1,1),  v(512,9216,1,1),  m(9216,512,1,1): 47 x 27901.2 us = 1.31136e+06 us (7360.01 GFLOPS/s)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I could also look into reducing Bc to 32 (and row_split to 2) again, but that is not currently working. We can look at that in a follow-up PR.


return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
};

#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
Expand All @@ -3187,15 +3201,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} \
} \
} \
Expand Down Expand Up @@ -8334,41 +8348,49 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;

VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);

return supported;
}

static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false);
const uint32_t Br = rows_cols[0];
const uint32_t Bc = rows_cols[1];

const uint32_t MatBr = 16, MatBc = 16;

const uint32_t row_split = Bc / MatBc;

const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);

const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;

const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * acctype;
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);

const uint32_t qstride = hsk_pad / 4 + 2;
const uint32_t Qf = Br * qstride * f16vec4;

const uint32_t psh_stride = Br / 4 + 2;
const uint32_t Psh = Bc * psh_stride * f16vec4;

const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype;

const uint32_t kshstride = hsk_pad / 4 + 2;
const uint32_t ksh = Bc * kshstride * f16vec4;
const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA;
const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
const uint32_t vsh_stride = MatBc / 4 * row_split;
const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;

const uint32_t slope = Br * sizeof(float);
const uint32_t slope = Br * acctype;

const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;

VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);

return supported;
}
Expand Down Expand Up @@ -8432,7 +8454,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);

const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type);

if (!coopmat_shape_supported || !coopmat_shmem_supported) {
path = FA_SCALAR;
Expand Down
6 changes: 6 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;

// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
const uint32_t HSK_pad = (HSK + 15) & ~15;
Expand Down Expand Up @@ -74,6 +76,10 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif

#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif

#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
Expand Down
Loading
Loading