Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,7 @@ void init_indices()
// and breaking the alignment detection.
m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}

// Bias applied to softmax to stay in fp16 range.
// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
20 changes: 10 additions & 10 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ void main() {
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
Qf16 *= float16_t(p.scale);

coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);

coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

Expand Down Expand Up @@ -223,6 +223,8 @@ void main() {

coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);

rowmax += coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(FATTN_KQ_MAX_OFFSET);

coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;

// M = max(rowmax, Mold)
Expand Down Expand Up @@ -265,11 +267,8 @@ void main() {
// resize eM by using smear/reduce
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

// multiply with fp16 accumulation, then add to O.
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
PV = coopMatMulAdd(P_A, V, PV);

O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
O *= coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag);
O = coopMatMulAdd(P_A, V, O);
}

// If there is split_k, then the split_k resolve shader does the final
Expand Down Expand Up @@ -311,7 +310,7 @@ void main() {
if (sink > Mr[i]) {
ms = exp(Mr[i] - sink);

O[i] *= ms;
O[i] *= float16_t(ms);
} else {
vs = exp(sink - Mr[i]);
}
Expand All @@ -325,15 +324,16 @@ void main() {
Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
}

O = Ldiag*O;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(Ldiag)*O_D;

#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
[[unroll]] for (uint i = 0; i < O_D.length(); ++i) { O_D[i] = clamp(O_D[i], D_TYPE(-ACC_TYPE_MAX), D_TYPE(ACC_TYPE_MAX)); }
#endif

uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;

coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else {
Expand Down
Loading