Skip to content

Commit a2b5057

Browse files
ikawrakowIwan Kawrakow
andauthored
Bug fixes from mainline (#439)
* Add __syncthreads() to the new FA kernel * Clearing padding --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 2ec2229 commit a2b5057

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

ggml/include/ggml-backend.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ extern "C" {
2222
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
2323
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
2424
GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
25-
GGML_API GGML_CALL size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
25+
GGML_API GGML_CALL size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
2626
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
2727

2828
// buffer
@@ -39,7 +39,7 @@ extern "C" {
3939
GGML_API GGML_CALL void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
4040
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
4141
GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer);
42-
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
42+
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor);
4343
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
4444
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
4545
GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);

ggml/src/ggml-backend.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
3535
return SIZE_MAX;
3636
}
3737

38-
GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
38+
GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
3939
// get_alloc_size is optional, defaults to ggml_nbytes
4040
if (buft->iface.get_alloc_size) {
4141
size_t size = buft->iface.get_alloc_size(buft, tensor);
@@ -114,7 +114,7 @@ size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) {
114114
return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer));
115115
}
116116

117-
size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
117+
size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor) {
118118
return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
119119
}
120120

ggml/src/ggml-cuda.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2101,13 +2101,19 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
21012101
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
21022102
const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
21032103

2104+
// If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.
2105+
// But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data.
2106+
// Therefore, in such cases use cuBLAS.
2107+
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
2108+
&& ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
2109+
21042110
bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
21052111
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
21062112
&& src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
2107-
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
2113+
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
21082114
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
21092115
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
2110-
bool use_mul_mat_q = ggml_is_quantized(src0->type)
2116+
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
21112117
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
21122118

21132119
// if mmvq is available it's a better choice than dmmv:

ggml/src/ggml-cuda/fattn-new-mma.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
10931093
}
10941094
}
10951095

1096+
// do we really need this?
1097+
__syncthreads();
1098+
10961099
// Write back combined meta data:
10971100
#pragma unroll
10981101
for (int imeta = 0; imeta < nmeta; ++imeta) {
@@ -1112,6 +1115,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
11121115
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
11131116
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
11141117
}
1118+
} else if (np > 1) {
1119+
// Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
1120+
// Therefore, all other warps also need to execute a __syncthreads().
1121+
// Otherwise the points at which warps synchronize with each other would become misaligned.
1122+
__syncthreads();
11151123
}
11161124

11171125
#pragma unroll

0 commit comments

Comments
 (0)