Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,46 @@ typedef void (* fattn_kernel_t)(
typedef float (*vec_dot_KQ_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);

struct ggml_cuda_flash_attn_ext_f16_extra_data {
uintptr_t K;
uintptr_t V;
uintptr_t end;
};

static inline ggml_cuda_flash_attn_ext_f16_extra_data ggml_cuda_flash_attn_ext_get_f16_extra_data(
const ggml_tensor * dst, const bool need_f16_K, const bool need_f16_V) {
GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT);

const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];

GGML_ASSERT(K != nullptr);
GGML_ASSERT(V != nullptr);

const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));

ggml_cuda_flash_attn_ext_f16_extra_data data = {};
data.end = (uintptr_t) dst->data + ggml_nbytes(dst);

if (need_f16_K && K->type != GGML_TYPE_F16) {
data.end = GGML_PAD(data.end, 128);
data.K = data.end;
data.end += ggml_nelements(K)*ggml_type_size(GGML_TYPE_F16);
}

if (need_f16_V && V->type != GGML_TYPE_F16) {
if (V_is_K_view) {
data.V = data.K;
} else {
data.end = GGML_PAD(data.end, 128);
data.V = data.end;
data.end += ggml_nelements(V)*ggml_type_size(GGML_TYPE_F16);
}
}

return data;
}

template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
Expand Down Expand Up @@ -952,8 +992,9 @@ void launch_fattn(
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;

ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra =
ggml_cuda_flash_attn_ext_get_f16_extra_data(KQV, need_f16_K, need_f16_V);

ggml_cuda_pool_alloc<int> KV_max(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
Expand All @@ -972,10 +1013,11 @@ void launch_fattn(
const size_t bs = ggml_blck_size(K->type);
const size_t ts = ggml_type_size(K->type);

K_f16.alloc(ggml_nelements(K));
GGML_ASSERT(f16_extra.K != 0);
half * K_f16 = (half *) f16_extra.K;
if (ggml_is_contiguously_allocated(K)) {
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
to_fp16(K_data, K_f16, ggml_nelements(K), main_stream);

nb11 = nb11*bs*sizeof(half)/ts;
nb12 = nb12*bs*sizeof(half)/ts;
Expand All @@ -986,13 +1028,13 @@ void launch_fattn(
const int64_t s01 = nb11 / ts;
const int64_t s02 = nb12 / ts;
const int64_t s03 = nb13 / ts;
to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
to_fp16(K_data, K_f16, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);

nb11 = K->ne[0] * sizeof(half);
nb12 = K->ne[1] * nb11;
nb13 = K->ne[2] * nb12;
}
K_data = (char *) K_f16.ptr;
K_data = (char *) K_f16;
}

if (need_f16_V && V->type != GGML_TYPE_F16) {
Expand All @@ -1005,11 +1047,12 @@ void launch_fattn(
const size_t bs = ggml_blck_size(V->type);
const size_t ts = ggml_type_size(V->type);

V_f16.alloc(ggml_nelements(V));
GGML_ASSERT(f16_extra.V != 0);
half * V_f16 = (half *) f16_extra.V;
if (ggml_is_contiguously_allocated(V)) {
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
V_data = (char *) V_f16.ptr;
to_fp16(V_data, V_f16, ggml_nelements(V), main_stream);
V_data = (char *) V_f16;

nb21 = nb21*bs*sizeof(half)/ts;
nb22 = nb22*bs*sizeof(half)/ts;
Expand All @@ -1020,13 +1063,13 @@ void launch_fattn(
const int64_t s01 = nb21 / ts;
const int64_t s02 = nb22 / ts;
const int64_t s03 = nb23 / ts;
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
to_fp16(V_data, V_f16, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);

nb21 = V->ne[0] * sizeof(half);
nb22 = V->ne[1] * nb21;
nb23 = V->ne[2] * nb22;
}
V_data = (char *) V_f16.ptr;
V_data = (char *) V_f16;
}
}

Expand Down
35 changes: 35 additions & 0 deletions ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,41 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_TILE;
}

size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst) {
GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT);

Comment thread
JohannesGaessler marked this conversation as resolved.
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];

GGML_ASSERT(K != nullptr);
GGML_ASSERT(V != nullptr);

const best_fattn_kernel kernel = ggml_cuda_get_best_fattn_kernel(device, dst);

bool need_f16_K = false;
bool need_f16_V = false;

switch (kernel) {
case BEST_FATTN_KERNEL_TILE:
case BEST_FATTN_KERNEL_WMMA_F16:
case BEST_FATTN_KERNEL_MMA_F16:
need_f16_K = true;
need_f16_V = true;
break;
case BEST_FATTN_KERNEL_VEC:
need_f16_K = K->type == GGML_TYPE_F32;
need_f16_V = V->type == GGML_TYPE_F32;
break;
case BEST_FATTN_KERNEL_NONE:
break;
}

const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra =
ggml_cuda_flash_attn_ext_get_f16_extra_data(dst, need_f16_K, need_f16_V);

return f16_extra.end - (uintptr_t) dst->data;
}

void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_set_device(ctx.device);
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/fattn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);

size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst);
8 changes: 5 additions & 3 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,11 @@ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_ty
}

static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
size_t size = ggml_nbytes(tensor);
ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *) buft->context;

size_t size = tensor->op == GGML_OP_FLASH_ATTN_EXT
? ggml_cuda_flash_attn_ext_get_alloc_size(buft_ctx->device, tensor)
: ggml_nbytes(tensor);
int64_t ne0 = tensor->ne[0];

if (ggml_is_quantized(tensor->type)) {
Expand All @@ -812,8 +816,6 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t
}

return size;

GGML_UNUSED(buft);
}

static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
Expand Down
Loading