@@ -2286,14 +2286,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
     };
 
 #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true),   1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true),   fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true),   1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true),   fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true),   1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true),   fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true),   1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true),   fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
         CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
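Note on the hunk above: the only functional change is the descriptor-count argument of ggml_vk_create_pipeline going from 5 to 6, making room for the new attention-sinks buffer. A minimal sketch of the binding order this seems to imply, mirroring the order of the vk_subbuffer lists at the dispatch sites later in this diff (the enum and its names are illustrative assumptions, not identifiers from the shaders):

#include <cstdint>

// Assumed flash-attention descriptor layout after this change: Q, K, V, mask,
// sinks, destination. The sinks slot is the one that raises the count to 6.
enum fa_binding : uint32_t {
    FA_BINDING_Q     = 0,
    FA_BINDING_K     = 1,
    FA_BINDING_V     = 2,
    FA_BINDING_MASK  = 3,
    FA_BINDING_SINKS = 4, // new
    FA_BINDING_DST   = 5,
};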
@@ -2910,7 +2910,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4],   "get_rows_mxfp4_f32",   get_rows_mxfp4_f32_len,   get_rows_mxfp4_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
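The fa_split_k_reduce pipeline likewise gains a third descriptor (the sinks buffer) and a fifth 32-bit push constant. A hedged host-side mirror of that push-constant block, matching the order of the pc2 array filled in later in this diff (the struct and its field names are assumptions):

#include <cstdint>

// Mirrors the 5 * sizeof(uint32_t) push-constant range of fa_split_k_reduce,
// in the order pc2 is populated: { HSV, ne1, ne3, split_k, sinks != nullptr }.
struct fa_split_k_reduce_push_constants {
    uint32_t hsv;       // value head size
    uint32_t ne1;       // number of query rows
    uint32_t ne3;       // outer batch dimension
    uint32_t split_k;   // number of partial results to combine
    uint32_t has_sinks; // 1 if a sinks tensor is bound, 0 otherwise
};
static_assert(sizeof(fa_split_k_reduce_push_constants) == 5 * sizeof(uint32_t),
              "layout must match the pipeline's push-constant size");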
@@ -6507,11 +6507,14 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     return supported;
 }
 
-static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
     std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
     std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    if (sinks) {
+        std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
+    }
     std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
 
     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
@@ -6710,10 +6713,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
-    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
+    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
+    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;
 
-    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
+    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;
 
     if (ctx->device->uma) {
         ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
@@ -6728,6 +6731,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
             M_uma = d_M != nullptr;
         }
+        if (sinks) {
+            ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
+            S_uma = d_S != nullptr;
+        }
     }
 
 
@@ -6763,7 +6770,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
+    if (!S_uma) {
+        d_S = d_Q;
+        s_buf_offset = q_buf_offset;
+        if (sinks) {
+            ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
+            d_S = s_buf_ctx->dev_buffer;
+            s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
+        }
+    }
+
+    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
 
     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
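mask_n_head_log2 now packs three values into one 32-bit push constant: n_head_log2 in the low bits, a has-mask flag at bit 16, and a has-sinks flag at bit 24. The packing below is taken from the hunk above; the unpacking masks are an assumption about how the shader side would read it back:

#include <cassert>
#include <cstdint>

static uint32_t pack_mask_n_head_log2(bool has_sinks, bool has_mask, uint32_t n_head_log2) {
    // Same expression as the diff, with the tensor-pointer checks replaced by bools.
    return ((uint32_t)has_sinks << 24) | ((uint32_t)has_mask << 16) | n_head_log2;
}

int main() {
    const uint32_t packed      = pack_mask_n_head_log2(true, true, 5);
    const uint32_t n_head_log2 = packed & 0xFFFF;       // assumed: low 16 bits
    const bool     has_mask    = (packed >> 16) & 0x1;  // assumed: bit 16
    const bool     has_sinks   = (packed >> 24) & 0x1;  // assumed: bit 24
    assert(n_head_log2 == 5 && has_mask && has_sinks);
    return 0;
}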
@@ -6787,6 +6804,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                         vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                                        vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                                     },
                                     // We only use split_k when group query attention is enabled, which means
@@ -6796,10 +6814,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                     pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
+        const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
                                     {
                                         vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+                                        vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                                     },
                                     pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
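In the split-k path the reduce pass is the natural place to fold the sinks in: each split contributes a running max and a sum of exponentials, and a sink presumably just adds one extra exp(sink - max) term to the combined denominator while contributing nothing to the numerator. A scalar sketch of that combination for a single output component, under those assumptions (the shader's actual buffer layout is not shown here):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// m[i]: local max seen by split i, l[i]: its sum of exp(logit - m[i]),
// o[i]: its unnormalized output component accumulated with the same scaling.
static float combine_with_sink(const std::vector<float>& m, const std::vector<float>& l,
                               const std::vector<float>& o, float sink) {
    float m_all = sink;                    // the sink takes part in the global max
    for (float mi : m) m_all = std::max(m_all, mi);

    float l_all = std::exp(sink - m_all);  // sink enters the denominator only
    float o_all = 0.0f;
    for (size_t i = 0; i < m.size(); ++i) {
        const float scale = std::exp(m[i] - m_all);
        l_all += l[i] * scale;
        o_all += o[i] * scale;
    }
    return o_all / l_all;
}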
@@ -6810,6 +6829,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                         vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                                        vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                                     },
                                     pc, { workgroups_x, workgroups_y, workgroups_z });
@@ -9874,7 +9894,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
 
     case GGML_OP_FLASH_ATTN_EXT:
-        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
+        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun);
 
         break;
 
@@ -10951,8 +10971,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
                     return false;
                 }
-                // TODO: support attention sinks [TAG_ATTN_SINKS]
-                if (op->src[4]) {
+                if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
                     return false;
                 }
                 if (op->src[0]->type != GGML_TYPE_F32) {
@@ -11547,6 +11566,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
+        if (src_clone[4]) {
+            ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
+        }
     } else if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
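For the reference check above, the CPU-side clone attaches the sinks with ggml_flash_attn_ext_add_sinks, which is the same call a graph builder uses to opt in. A hedged sketch of that caller-side pattern (the helper and its parameters are illustrative; only the two ggml calls appear in this diff):

#include "ggml.h"

// Build the flash-attention op, then attach an optional per-head F32 sinks tensor.
static struct ggml_tensor * build_flash_attn_with_sinks(
        struct ggml_context * ctx,
        struct ggml_tensor * q, struct ggml_tensor * k, struct ggml_tensor * v,
        struct ggml_tensor * mask, struct ggml_tensor * sinks,
        float scale, float max_bias, float logit_softcap) {
    struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap);
    if (sinks) {
        // The supports_op change above only admits F32 sinks on this backend.
        ggml_flash_attn_ext_add_sinks(out, sinks);
    }
    return out;
}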