@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-
-template <int D, int ncols, int KQ_stride> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+template <int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
-    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
-
-    const int iter_k = ne11 / KQ_stride;
-    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+    constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
+    const int j     = blockIdx.y;
+    const int c     = blockIdx.z;
+    const int jc    = j*ncols2 + c;
+    const int tid   = threadIdx.x;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
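
The kbc0/kbc0_stop arithmetic above is the stream-k work split: the iter_k*iter_j*(ne02/ncols2) flattened (head group, column tile, K tile) work units are divided into gridDim.x contiguous half-open ranges, one per CUDA block, so the fixup kernel can recompute exactly which units the main kernel's block bidx0 owned. A minimal sketch of that split (hypothetical standalone helper, not part of the patch):

// Sketch: stream-k assignment of work units to blocks, mirroring the kernel's
// integer arithmetic. Block bidx owns the half-open range [kbc(bidx), kbc(bidx + 1)).
static __host__ __device__ int kbc(const int bidx, const int iter_k, const int iter_j,
                                   const int ne02, const int ncols2, const int nblocks) {
    return bidx*iter_k*iter_j*(ne02/ncols2) / nblocks; // truncating division, same as the kernel
}
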
@@ -548,59 +546,53 @@ static __global__ void flash_attn_stream_k_fixup(
     const int channel = kbc0 / (iter_k*iter_j);
     const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
 
-    dst += jt*ncols*ne02*D + channel*D;
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
 
-    // Load the partial result that needs a fixup:
-    float dst_val[ncols] = {0.0f};
-    float max_val[ncols] = {0.0f};
-    float rowsum[ncols]  = {0.0f};
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            break;
-        }
-        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
 
-        const float2 tmp = dst_fixup[bidx0*ncols + j];
-        max_val[j] = tmp.x;
-        rowsum[j]  = tmp.y;
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum  = tmp.y;
     }
 
     // Iterate over previous blocks and compute the combined results.
     // All CUDA blocks that get here must have a previous block that needs a fixup.
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while (true) {
-        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
             continue;
         }
 
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            if (jt*ncols + j >= ne01) {
-                break;
-            }
-            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
 
-            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
 
-            // Scale the current and new value accumulators depending on the max. values.
-            const float max_val_new = fmaxf(max_val[j], tmp.x);
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
 
-            const float diff_val = max_val[j] - max_val_new;
-            const float diff_add = tmp.x - max_val_new;
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x   - max_val_new;
 
-            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
-            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
 
-            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
-            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val*rowsum  + scale_add*tmp.y;
 
-            max_val[j] = max_val_new;
-        }
+        max_val = max_val_new;
 
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
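
The loop body above is the standard numerically stable merge of two partial softmax accumulators: each partial result carries its running maximum and row sum, and both the value accumulator and the row sum are rescaled by exp(old_max - new_max) before being added, with scales whose exponent falls below SOFTMAX_FTZ_THRESHOLD flushed to zero. A self-contained sketch of one merge step (plain host-side mirror of the kernel logic; the names are hypothetical):

// Sketch: combining two partial softmax results, as in the scale_val/scale_add
// code above. Each Partial holds a value accumulator, a running max, and a row sum.
#include <algorithm>
#include <cmath>

struct Partial { float v, m, s; };

static Partial merge(const Partial a, const Partial b, const float ftz_threshold) {
    const float m_new = std::max(a.m, b.m);
    // Rescale both accumulators to the common maximum; flush underflowed scales to zero.
    const float scale_a = a.m - m_new >= ftz_threshold ? std::exp(a.m - m_new) : 0.0f;
    const float scale_b = b.m - m_new >= ftz_threshold ? std::exp(b.m - m_new) : 0.0f;
    return {scale_a*a.v + scale_b*b.v, m_new, scale_a*a.s + scale_b*b.s};
}

The final attention output is the merged value accumulator divided by the merged row sum, which is exactly what the write-back in the next hunk computes.
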
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 
     // Write back final result:
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            return;
-        }
-        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
-    }
+    *dst = dst_val / rowsum;
 }
 
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif // __clang__
-
 template <int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 // parallel_blocks == 0 is stream-k decomposition
-template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
         const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1 * ncols2;
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];
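
Splitting the old cols_per_block into ncols1 and ncols2 means one block now covers ncols1 Q columns for each of ncols2 heads; the tile count further down divides ne02 by ncols2 accordingly, so the total work is unchanged. A hypothetical instantiation (illustrative template arguments, not taken from the patch; reading ncols2 as grouping heads that share a K/V head, as under GQA, is an assumption):

// Sketch: head size 128, 8 Q columns per tile, 2 heads per block -> ncols = 16
// output rows per block; parallel_blocks == 0 selects the stream-k decomposition.
launch_fattn<128, /*ncols1=*/8, /*ncols2=*/2, /*parallel_blocks=*/0, FATTN_KQ_STRIDE>
    (ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
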
@@ -763,25 +747,26 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
-    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
-        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k = 2*nsm;
+        const int nblocks_stream_k = max_blocks;
 
-        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         blocks_num.x = parallel_blocks*ntiles_x;
        blocks_num.y = Q->ne[2];
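
The heuristic above measures wave quantization: with at most max_blocks = 2*nsm blocks resident at once, the whole-tile decomposition runs ceil(ntiles_total / max_blocks) waves, and a mostly idle last wave makes stream-k worthwhile. A worked example with assumed numbers (illustrative only):

// Sketch: the wave-quantization arithmetic with assumed values.
const int nsm          = 32;        // assumed SM count
const int max_blocks   = 2*nsm;     // 64 blocks per wave
const int ntiles_total = 100;       // assumed tile count
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;             // = 2
const int tiles_efficiency_percent = 100*ntiles_total / (max_blocks*tiles_nwaves); // = 78
// 78 >= 75: on pre-Ada GPUs the whole-tile decomposition is kept. With only 70
// tiles the efficiency would drop to 54 and stream-k would be used; on Ada
// Lovelace or newer, stream-k is used unconditionally.
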
@@ -793,7 +778,6 @@ void launch_fattn(
         }
     }
 
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
    float logit_softcap = 0.0f;
@@ -832,9 +816,9 @@ void launch_fattn(
     if constexpr (parallel_blocks == 0) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
-            const dim3 blocks_num_combine = blocks_num;
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
        }
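
The 3D fixup grid now lines up one-to-one with the kernel's indexing; a short mapping for reference (comments only, mirroring the code above):

// Sketch: how the fixup launch maps onto the kernel's indices.
// blocks_num_combine = {blocks_num.x, ncols1, ncols2}, block_dim_combine = (D, 1, 1):
//   blockIdx.x  -> bidx0 (stream-k block whose tail tile needs the fixup)
//   blockIdx.y  -> j     (Q column within the ncols1 tile)
//   blockIdx.z  -> c     (head within the ncols2 group)
//   threadIdx.x -> tid   (element of the head dimension D)
// Out-of-range rows exit early via the jt*ncols1 + j >= ne01 check in the kernel.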