@@ -540,12 +540,10 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
                 KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
             }
 
-            if (!oob_check || i_KQ < k_VKQ_sup) {
-                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
-                    slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
+            KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
+                slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
 
-                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
-            }
+            KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
         }
 
         KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
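The hunk above removes a divergent bounds check from the hot loop: the mask add and the running-max update now execute on every lane, with the out-of-range case folded into the select. A minimal sketch of the pattern, assuming the accumulator is already well defined for padded slots (the helper name below is illustrative, not from this patch):

// Branchless form: the compiler lowers the ternary to a predicated select,
// so all 32 lanes of a warp stay converged through the fmaxf that follows.
__device__ __forceinline__ float add_masked(float acc, float term, bool in_bounds) {
    return acc + (in_bounds ? term : 0.0f);
}
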
@@ -583,9 +581,10 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
         float KQ_sum_add = 0.0f;
 #pragma unroll
         for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
-            const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
-                expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
-            KQ_sum_add += val;
+            const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
+            if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
+                KQ_sum_add += val;
+            }
             tmp[i0/(np*warp_size)][jc1] = val;
         }
         KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
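Here the expf is hoisted out of the predicate: every lane computes its exponential and stores it into tmp unconditionally, and only the accumulation into the softmax denominator stays guarded. A sketch of the compute-unconditionally, accumulate-conditionally idiom, assuming out-of-range slots are neutralized downstream (function and parameter names are assumptions for illustration):

// Compute unconditionally, accumulate conditionally: the exponential runs on
// every lane so the unrolled loop stays branch-free, and only the softmax
// denominator is restricted to in-bounds contributions.
__device__ __forceinline__ float exp_store_and_sum(float kq, float kq_max, bool in_bounds,
                                                   float *tmp_slot, float kq_sum) {
    const float val = expf(kq - kq_max);       // uniform work on all lanes
    *tmp_slot = val;                           // full tile stored for the V*softmax pass
    return kq_sum + (in_bounds ? val : 0.0f);  // predicate only the accumulation
}
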
@@ -976,6 +975,26 @@ static __global__ void flash_attn_tile(
         }
     }
 
+    if (gridDim.y == 1) {
+#pragma unroll
+        for (int jc0 = 0; jc0 < cpw; ++jc0) {
+#ifdef FAST_FP16_AVAILABLE
+            const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
+#pragma unroll
+            for (int i = 0; i < (DVp/2)/warp_size; ++i) {
+                VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
+            }
+#else
+            const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
+#pragma unroll
+            for (int i = 0; i < (DVp/2)/warp_size; ++i) {
+                VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
+                VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
+            }
+#endif // FAST_FP16_AVAILABLE
+        }
+    }
+
     // Write back results:
 #pragma unroll
     for (int jc0 = 0; jc0 < cpw; ++jc0) {
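The block added above normalizes the accumulators once, in registers, before write-back, and only in the single-pass case (gridDim.y == 1). On FAST_FP16_AVAILABLE targets, broadcasting the reciprocal into both halves of a half2 lets one packed multiply scale two output elements at a time. A self-contained sketch of that idiom (helper name and shapes are assumptions, not from this patch):

#include <cuda_fp16.h>

// Scale n packed half2 accumulators by 1.0f/denom. Requires a target with
// native half2 arithmetic (the patch guards this with FAST_FP16_AVAILABLE).
template <int n>
__device__ __forceinline__ void scale_half2_regs(half2 *vkq, const float denom) {
    const half2 inv = make_half2(1.0f/denom, 1.0f/denom); // broadcast reciprocal
#pragma unroll
    for (int i = 0; i < n; ++i) {
        vkq[i] *= inv; // one packed multiply scales both halves
    }
}
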
@@ -988,8 +1007,6 @@ static __global__ void flash_attn_tile(
             return;
         }
 
-        const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
-
         const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
@@ -1000,8 +1017,6 @@ static __global__ void flash_attn_tile(
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                 tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
-                tmp[i1].x *= scale;
-                tmp[i1].y *= scale;
             }
             if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
                 ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
@@ -1012,11 +1027,6 @@ static __global__ void flash_attn_tile(
 #pragma unroll
         for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
             if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
-#pragma unroll
-                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
-                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
-                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
-                }
                 ggml_cuda_memcpy_1<cpy_ne_D*4>(
                     &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
                     &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
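With the normalization hoisted into the dedicated loop, the write-back path no longer scales anything: when gridDim.y > 1 each block emits unnormalized partials and the division is deferred to a later combine pass. A simplified, illustrative sketch of why deferring the division is safe (names and layout are assumptions, not the actual combine kernel):

#include <math.h>

// Each of nparts blocks contributes an unnormalized partial numerator vkq_p,
// a partial softmax denominator sum_p, and the local score maximum max_p.
// Rebasing every partial onto the global maximum and dividing once at the end
// gives the same result as normalizing inside each block.
float combine_partials(const float *vkq_p, const float *sum_p, const float *max_p, int nparts) {
    float max_glob = -INFINITY;
    for (int p = 0; p < nparts; ++p) {
        max_glob = fmaxf(max_glob, max_p[p]);
    }
    float num = 0.0f;
    float den = 0.0f;
    for (int p = 0; p < nparts; ++p) {
        const float w = expf(max_p[p] - max_glob); // rescale factor per partial
        num += w*vkq_p[p];
        den += w*sum_p[p];
    }
    return num/den; // the single final division
}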