@@ -710,11 +710,10 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     float scale = (1.0f / sqrt((float)d_head));
 
-    if (flash_attn) {
-        // TODO: remove before merge
-        LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    }
-    // is there anything oddly shaped??
+    // if (flash_attn) {
+    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    // }
+    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
@@ -725,17 +724,17 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     can_use_flash_attn = can_use_flash_attn && d_head <= 256; // double check
 
     if (mask != nullptr) {
-        // TODO: figure out if we can bend t5 to work too
+        // TODO(Green-Sky): figure out if we can bend t5 to work too
         can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
         can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
     }
 
-    // TODO: more pad or disable for funny tensor shapes
+    // TODO(Green-Sky): more pad or disable for funny tensor shapes
 
     ggml_tensor* kqv = nullptr;
     // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        LOG_DEBUG("using flash attention");
+        // LOG_DEBUG("using flash attention");
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
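For reference, a minimal standalone sketch of the flash-attention eligibility check this hunk gates on; the struct and function names below are illustrative only (not part of the upstream code), and any preconditions set outside the visible hunk are omitted:

#include <cstdint>

// Illustrative sketch only: mirrors the conditions visible in this hunk.
struct MaskShape {
    int64_t ne2; // mask->ne[2]
    int64_t ne3; // mask->ne[3]
};

static bool can_use_flash_attn_sketch(int64_t d_head, const MaskShape* mask) {
    bool ok = true;
    ok = ok && d_head <= 256; // head-size limit ("double check" in the hunk above)
    if (mask != nullptr) {    // mask must be a single 2D slice (ne[2] == ne[3] == 1)
        ok = ok && mask->ne2 == 1;
        ok = ok && mask->ne3 == 1;
    }
    return ok;
}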