Skip to content

Commit 651f696

Browse files
committed
Revert "CUDA: fix numerical issues in tile FA kernel (ggml-org#16540)"
This reverts commit 7049736.
1 parent 7318662 commit 651f696

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -540,12 +540,10 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
540540
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
541541
}
542542

543-
if (!oob_check || i_KQ < k_VKQ_sup) {
544-
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
545-
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
543+
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
544+
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
546545

547-
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
548-
}
546+
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
549547
}
550548

551549
KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
@@ -583,9 +581,10 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
583581
float KQ_sum_add = 0.0f;
584582
#pragma unroll
585583
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
586-
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
587-
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
588-
KQ_sum_add += val;
584+
const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
585+
if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
586+
KQ_sum_add += val;
587+
}
589588
tmp[i0/(np*warp_size)][jc1] = val;
590589
}
591590
KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
@@ -976,6 +975,26 @@ static __global__ void flash_attn_tile(
976975
}
977976
}
978977

978+
if (gridDim.y == 1) {
979+
#pragma unroll
980+
for (int jc0 = 0; jc0 < cpw; ++jc0) {
981+
#ifdef FAST_FP16_AVAILABLE
982+
const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
983+
#pragma unroll
984+
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
985+
VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
986+
}
987+
#else
988+
const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
989+
#pragma unroll
990+
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
991+
VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
992+
VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
993+
}
994+
#endif // FAST_FP16_AVAILABLE
995+
}
996+
}
997+
979998
// Write back results:
980999
#pragma unroll
9811000
for (int jc0 = 0; jc0 < cpw; ++jc0) {
@@ -988,8 +1007,6 @@ static __global__ void flash_attn_tile(
9881007
return;
9891008
}
9901009

991-
const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
992-
9931010
const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
9941011

9951012
#ifdef FAST_FP16_AVAILABLE
@@ -1000,8 +1017,6 @@ static __global__ void flash_attn_tile(
10001017
#pragma unroll
10011018
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
10021019
tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
1003-
tmp[i1].x *= scale;
1004-
tmp[i1].y *= scale;
10051020
}
10061021
if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
10071022
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
@@ -1012,11 +1027,6 @@ static __global__ void flash_attn_tile(
10121027
#pragma unroll
10131028
for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
10141029
if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
1015-
#pragma unroll
1016-
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
1017-
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
1018-
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
1019-
}
10201030
ggml_cuda_memcpy_1<cpy_ne_D*4>(
10211031
&dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
10221032
&VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);

0 commit comments

Comments
 (0)