@@ -709,11 +709,10 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     float scale = (1.0f / sqrt((float)d_head));
 
-    if (flash_attn) {
-        // TODO: remove before merge
-        LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    }
-    // is there anything oddly shaped??
+    // if (flash_attn) {
+    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    // }
+    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
@@ -724,17 +723,17 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
 
     if (mask != nullptr) {
-        // TODO: figure out if we can bend t5 to work too
+        // TODO(Green-Sky): figure out if we can bend t5 to work too
         can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
         can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
     }
 
-    // TODO: more pad or disable for funny tensor shapes
+    // TODO(Green-Sky): more pad or disable for funny tensor shapes
 
     ggml_tensor* kqv = nullptr;
     // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        LOG_DEBUG("using flash attention");
+        // LOG_DEBUG("using flash attention");
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
0 commit comments