14 changes: 14 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -428,6 +428,20 @@ static __global__ void reduce_rows_f32(const float * x, float * dst, const int n
dst[row] = norm ? sum / ncols : sum;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_all(int x) {
#ifdef GGML_USE_HIP
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
}
return x;
#else
static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
return __all_sync(0xffffffff, x);
#endif // GGML_USE_HIP
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
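Note on the new helper: warp_reduce_all() returns a nonzero value iff the predicate is nonzero on every lane of the warp. On NVIDIA it lowers to a single __all_sync(); the HIP path emulates the reduction with XOR shuffles, since __all_sync() assumes 32-wide warps while AMD wavefronts can be 64 lanes wide. A minimal usage sketch (illustrative only, not part of the diff):

static __global__ void warp_reduce_all_demo(const int * x, int * out) {
    const int lane = threadIdx.x % WARP_SIZE;
    // Every lane contributes a predicate; all lanes receive the same result.
    const int all_nonzero = warp_reduce_all(x[lane] != 0);
    if (lane == 0) {
        out[blockIdx.x] = all_nonzero; // 1 iff all WARP_SIZE inputs were nonzero
    }
}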
75 changes: 73 additions & 2 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -500,6 +501,55 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
nullptr;
}

template <int ncols1>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_max(
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
const int ne31 = gridDim.x;
const int tid = threadIdx.x;
const int sequence = blockIdx.y;
const int jt = blockIdx.x;

mask += sequence*s33 + jt*ncols1*s31;

__shared__ int buf_iw[WARP_SIZE];
if (tid < WARP_SIZE) {
buf_iw[tid] = 1;
}
__syncthreads();

int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
int all_inf = 1;

#pragma unroll
for (int j = 0; j < ncols1; ++j) {
const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
}

all_inf = warp_reduce_all(all_inf);
if (tid % WARP_SIZE == 0) {
buf_iw[tid / WARP_SIZE] = all_inf;
}
__syncthreads();
all_inf = buf_iw[tid % WARP_SIZE];
__syncthreads();
all_inf = warp_reduce_all(all_inf);

if (!all_inf) {
KV_max_sj += FATTN_KQ_STRIDE;
break;
}
}

if (threadIdx.x != 0) {
return;
}

KV_max[sequence*ne31 + jt] = KV_max_sj;
}

template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
@@ -711,6 +761,7 @@ void launch_fattn(

ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
ggml_cuda_pool_alloc<int> KV_max(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

@@ -779,11 +830,30 @@
V_data = (char *) V_f16.ptr;
}

int parallel_blocks = 1;

const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
// Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
// multiple sequences of possibly different lengths.
if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);

const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);

const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;

KV_max.alloc(ne_KV_max);
flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
CUDA_CHECK(cudaGetLastError());
}

int parallel_blocks = 1;

const dim3 block_dim(warp_size, nwarps, 1);
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
@@ -870,6 +940,7 @@ void launch_fattn(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
KV_max.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
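Note on the semantics of flash_attn_mask_to_KV_max: for each (sequence, tile of ncols1 query rows) it scans the mask backwards in FATTN_KQ_STRIDE-wide blocks and records the first block boundary below which some entry is finite; the attention kernels then stop their KV loop at that bound, since fully masked blocks contribute nothing to the softmax. A rough host-side restatement, as a sketch only (it assumes a dense float mask with -INFINITY for masked entries and FATTN_KQ_STRIDE visible as a plain constant, whereas the kernel reads strided half2 data in parallel):

#include <cmath>

static int kv_max_reference(const float * mask, int n_kv, int ncols1, int stride_row) {
    int kv_max = n_kv; // exclusive upper bound, a multiple of FATTN_KQ_STRIDE
    while (kv_max > 0) {
        bool all_inf = true;
        for (int j = 0; j < ncols1 && all_inf; ++j) {
            for (int i = kv_max - FATTN_KQ_STRIDE; i < kv_max; ++i) {
                if (!std::isinf(mask[j*stride_row + i])) {
                    all_inf = false; // this block still has unmasked entries
                    break;
                }
            }
        }
        if (!all_inf) {
            break; // everything at or beyond kv_max is fully masked out
        }
        kv_max -= FATTN_KQ_STRIDE;
    }
    return kv_max;
}

launch_fattn only runs the scan when Q->ne[1] >= 1024 or Q->ne[3] > 1, so the extra kernel launch is paid only when whole stride blocks are plausibly skippable or sequence lengths may differ.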
21 changes: 16 additions & 5 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
}
}

template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
@@ -922,7 +923,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}

// Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
int kb0 = kb0_start;
for (; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
@@ -932,7 +934,7 @@
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}

// With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1204,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -1280,7 +1283,11 @@
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;

const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;

if (KV_max) {
kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
}

constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
@@ -1321,7 +1328,11 @@
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;

const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;

if (KV_max) {
kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
}

constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
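Note on the clamp above: KV_max stores an exclusive bound in tokens, while kb0 counts batches of c::nbatch_fa tokens, hence the division before the min. With illustrative numbers: if nbatch_fa == 64 and the mask is entirely -inf past token 1536, then KV_max[sequence*iter_j + jt] == 1536 and the kernel runs at most 1536/64 == 24 inner iterations, regardless of kb0_stop * kb_niter.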
4 changes: 3 additions & 1 deletion ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -90,7 +91,8 @@ static __global__ void flash_attn_tile_ext_f16(

__syncthreads();

for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
// Calculate KQ tile and keep track of new maximum KQ values:

half kqmax_new[ncols/nwarps];
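The f32 tile, vec, and WMMA kernels below consume the bound with the same pattern, condensed here as a sketch (KV_TILE stands in for the kernel-specific stride constants):

    const int k_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    for (int k0 = blockIdx.y*KV_TILE; k0 < k_max; k0 += gridDim.y*KV_TILE) {
        // ... process one KV tile; blocks past k_max are never touched ...
    }

No fixup is needed for the skipped range: fully masked positions receive zero softmax weight anyway, so truncating the loop leaves the output unchanged.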
4 changes: 3 additions & 1 deletion ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -99,7 +100,8 @@ static __global__ void flash_attn_tile_ext_f32(

__syncthreads();

for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
// Calculate KQ tile and keep track of new maximum KQ values:

float kqmax_new[ncols/nwarps];
26 changes: 3 additions & 23 deletions ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -10,6 +10,7 @@ static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -171,10 +172,11 @@ static __global__ void flash_attn_vec_ext_f16(

half2 VKQ[ncols] = {{0.0f, 0.0f}};

const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
// Increment pointers after each loop:
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {

@@ -185,29 +187,7 @@
for (int j = 0; j < ncols; ++j) {
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
}

__syncthreads();

// When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
// In such cases, skip the KV slice.
// On AMD __all_sync would not work correctly because it assumes a warp size of 64.
#ifndef GGML_USE_HIP
bool skip = true;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
skip = skip && isinf(tmp.x) && isinf(tmp.y);
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP
}

// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
25 changes: 3 additions & 22 deletions ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -10,6 +10,7 @@ static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -177,10 +178,11 @@ static __global__ void flash_attn_vec_ext_f32(

float VKQ[ncols] = {0.0f};

const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
// Increment pointers after each loop:
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {

@@ -191,28 +193,7 @@
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
}

__syncthreads();

// When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
// In such cases, skip the KV slice.
// On AMD __all_sync would not work correctly because it assumes a warp size of 64.
#ifndef GGML_USE_HIP
bool skip = true;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

skip = skip && isinf(maskf_shared[j*D + i]);
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP
}

float kqmax_new_arr[ncols];
4 changes: 3 additions & 1 deletion ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -165,7 +166,8 @@
__syncthreads();

// Iterate over ne11 == previous tokens:
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/fattn.cu
@@ -315,7 +315,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
(Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {