Skip to content

Commit 7049736

Browse files
CUDA: fix numerical issues in tile FA kernel (#16540)
1 parent 01d2bdc commit 7049736

File tree

1 file changed

+17
-27
lines changed

1 file changed

+17
-27
lines changed

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
540540
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
541541
}
542542

543-
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
544-
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
543+
if (!oob_check || i_KQ < k_VKQ_sup) {
544+
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
545+
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
545546

546-
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
547+
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
548+
}
547549
}
548550

549551
KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
@@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
581583
float KQ_sum_add = 0.0f;
582584
#pragma unroll
583585
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
584-
const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
585-
if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
586-
KQ_sum_add += val;
587-
}
586+
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
587+
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
588+
KQ_sum_add += val;
588589
tmp[i0/(np*warp_size)][jc1] = val;
589590
}
590591
KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
@@ -975,26 +976,6 @@ static __global__ void flash_attn_tile(
975976
}
976977
}
977978

978-
if (gridDim.y == 1) {
979-
#pragma unroll
980-
for (int jc0 = 0; jc0 < cpw; ++jc0) {
981-
#ifdef FAST_FP16_AVAILABLE
982-
const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
983-
#pragma unroll
984-
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
985-
VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
986-
}
987-
#else
988-
const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
989-
#pragma unroll
990-
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
991-
VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
992-
VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
993-
}
994-
#endif // FAST_FP16_AVAILABLE
995-
}
996-
}
997-
998979
// Write back results:
999980
#pragma unroll
1000981
for (int jc0 = 0; jc0 < cpw; ++jc0) {
@@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile(
1007988
return;
1008989
}
1009990

991+
const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
992+
1010993
const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
1011994

1012995
#ifdef FAST_FP16_AVAILABLE
@@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile(
10171000
#pragma unroll
10181001
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
10191002
tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
1003+
tmp[i1].x *= scale;
1004+
tmp[i1].y *= scale;
10201005
}
10211006
if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
10221007
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
@@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile(
10271012
#pragma unroll
10281013
for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
10291014
if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
1015+
#pragma unroll
1016+
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
1017+
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
1018+
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
1019+
}
10301020
ggml_cuda_memcpy_1<cpy_ne_D*4>(
10311021
&dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
10321022
&VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);

0 commit comments

Comments
 (0)