@@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
                 KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
             }
 
-            KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
-                slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
+            if (!oob_check || i_KQ < k_VKQ_sup) {
+                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
+                    slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
 
-            KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+            }
         }
 
         KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
@@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
             float KQ_sum_add = 0.0f;
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
-                const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
-                if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
-                    KQ_sum_add += val;
-                }
+                const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
+                    expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
+                KQ_sum_add += val;
                 tmp[i0/(np*warp_size)][jc1] = val;
             }
             KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
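
Both hunks above change how out-of-bounds K positions are handled when oob_check is active: the first excludes them from the mask add and the running maximum, and the second forces their softmax numerator to exactly 0.0f so they drop out of both KQ_sum_add and the tmp staging buffer without a divergent branch in the unrolled loop. A minimal sketch of that zero-padding pattern (the names softmax_row_sum, logits, n_padded, n_valid, and row_max are illustrative, not from the kernel):

// Illustrative only: lanes past n_valid contribute exactly 0.0f to the
// softmax sum instead of being skipped with a branch, mirroring the diff above.
__device__ float softmax_row_sum(const float * logits, int n_padded, int n_valid, float row_max) {
    float sum = 0.0f;
    for (int i = 0; i < n_padded; ++i) {
        // exp(logit - row_max) for valid lanes, a neutral 0.0f for padding:
        const float val = i < n_valid ? expf(logits[i] - row_max) : 0.0f;
        sum += val;
    }
    return sum;
}
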
@@ -975,26 +976,6 @@ static __global__ void flash_attn_tile(
         }
     }
 
-    if (gridDim.y == 1) {
-#pragma unroll
-        for (int jc0 = 0; jc0 < cpw; ++jc0) {
-#ifdef FAST_FP16_AVAILABLE
-            const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
-#pragma unroll
-            for (int i = 0; i < (DVp/2)/warp_size; ++i) {
-                VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
-            }
-#else
-            const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
-#pragma unroll
-            for (int i = 0; i < (DVp/2)/warp_size; ++i) {
-                VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
-                VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
-            }
-#endif // FAST_FP16_AVAILABLE
-        }
-    }
-
     // Write back results:
 #pragma unroll
     for (int jc0 = 0; jc0 < cpw; ++jc0) {
@@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile(
             return;
         }
 
+        const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
+
         const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
@@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile(
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                 tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
+                tmp[i1].x *= scale;
+                tmp[i1].y *= scale;
             }
             if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
                 ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
@@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile(
 #pragma unroll
         for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
             if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
+                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
+                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
+                }
                 ggml_cuda_memcpy_1<cpy_ne_D*4>(
                     &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
                     &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
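
The remaining hunks fold the softmax normalization into the write-back: the standalone pass that divided VKQ by KQ_sum when gridDim.y == 1 is removed, and instead each output element is multiplied by scale on its way to dst, with scale = 1.0f in the multi-block case where the partial results are presumably normalized later when they are combined. A hedged sketch of the resulting structure (write_back, acc, and single_pass are illustrative names, not the kernel's API):

// Illustrative only: apply the 1/sum normalization while writing results out,
// instead of running a separate scaling pass over the accumulator first.
__device__ void write_back(float * dst, const float * acc, int n, float kq_sum, bool single_pass) {
    // With a single block the kernel normalizes its own output; otherwise the
    // scale stays 1.0f and the partial results are normalized when combined.
    const float scale = single_pass ? 1.0f/kq_sum : 1.0f;
    for (int i = 0; i < n; ++i) {
        dst[i] = acc[i]*scale;
    }
}

This merges the scaling multiplies into the existing write-back loop, eliminating one full traversal of the VKQ registers in the single-block case.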