@@ -735,13 +735,39 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     float scale = (1.0f / sqrt((float)d_head));
 
-    bool use_flash_attn = false;
-    ggml_tensor* kqv = NULL;
+    LOG_DEBUG("attention_ext L_k:%d n_head:%d C:%d d_head:%d", L_k, n_head, C, d_head);
+
+    bool use_flash_attn = true;
+    // unclear whether L_k is n_context or n_token here; it depends on the caller
+    use_flash_attn = use_flash_attn && L_k % 256 == 0;   // FA path expects the KV length padded to a multiple of 256
+    use_flash_attn = use_flash_attn && d_head % 64 == 0; // FA kernels only cover certain head sizes; multiples of 64 are safe
+
+    if (mask != nullptr) {
+        // TODO: figure out if we can bend t5 to work too (its per-head mask fails the checks below)
+        use_flash_attn = use_flash_attn && mask->ne[2] == 1;
+        use_flash_attn = use_flash_attn && mask->ne[3] == 1;
+    }
+
+    // TODO: pad to supported shapes, or keep flash attention disabled for odd tensor shapes
+
+    ggml_tensor* kqv = nullptr;
     if (use_flash_attn) {
+        LOG_DEBUG("using flash attention");
+
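+        // the FA path wants K and V in F16, so cast them before the call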
+        k = ggml_cast(ctx, k, GGML_TYPE_F16);
+
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
-        LOG_DEBUG("k->ne[1] == %d", k->ne[1]);
+        v = ggml_cast(ctx, v, GGML_TYPE_F16);
+
         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
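+        // force F32 precision inside the FA op, since K/V were downcast to F16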
+        ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
+
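+        // FA already returns [N, L_q, n_head, d_head]; view it as 3D here
+        // (assuming L_q == L_k, i.e. self-attention) so the final reshape gives [N, L_q, C]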
+        kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
     } else {
         v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
         v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);  // [N * n_head, d_head, L_k]
@@ -757,10 +783,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         kq = ggml_soft_max_inplace(ctx, kq);
 
         kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, L_q, d_head]
+
+        kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);  // [N, n_head, L_q, d_head]
+        kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3);                 // [N, L_q, n_head, d_head]
     }
 
-    kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);  // [N, n_head, L_q, d_head]
-    kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, L_q, n_head, d_head]
+    kqv = ggml_cont(ctx, kqv);
     kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N);  // [N, L_q, C]
 
     return kqv;
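
Taken together, the new gating boils down to a single shape predicate. The sketch below restates it as a standalone helper; can_use_flash_attn is a hypothetical name, not part of this commit, and the thresholds are exactly the ones checked in the diff above:

#include <cstdint>
#include "ggml.h"

// Hypothetical helper mirroring the gating in ggml_nn_attention_ext:
// flash attention is only attempted when the shapes fit the ggml FA kernels.
static bool can_use_flash_attn(int64_t L_k, int64_t d_head, const struct ggml_tensor* mask) {
    if (L_k % 256 != 0) return false;    // KV length must be padded to a multiple of 256
    if (d_head % 64 != 0) return false;  // only certain head sizes are supported
    if (mask != nullptr) {
        // masks with per-head or per-batch dims (e.g. t5's relative-position
        // bias) are not handled by this path yet
        if (mask->ne[2] != 1 || mask->ne[3] != 1) return false;
    }
    return true;
}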