14 changes: 7 additions & 7 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -428,7 +428,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int jt,
         const int kb0,
         const int k_VKQ_sup) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int ncols = ncols1 * ncols2;
     constexpr int cols_per_warp = T_B_KQ::I;
     constexpr int cols_per_thread = get_cols_per_thread();
@@ -881,7 +881,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         tile_Q, tile_K, tile_V, tile_mask,
         Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
 #if defined(TURING_MMA_AVAILABLE)
@@ -944,7 +944,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr int ncols = ncols1 * ncols2;
@@ -1454,7 +1454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
         jt, kb0_start, kb0_stop);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
@@ -1481,7 +1481,7 @@ static __global__ void flash_attn_ext_f16(
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1503,7 +1503,7 @@ static __global__ void flash_attn_ext_f16(
 #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
 
 #if defined(AMD_WMMA_AVAILABLE)
-    if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
+    if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || (DKQ > 128 && DKQ != 576) || ncols2 == 1) {
         NO_DEVICE_CODE;
         return;
     }
@@ -1622,7 +1622,7 @@ static __global__ void flash_attn_ext_f16(
         ne31, ne32, ne33,
         nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>
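Note on the relaxed guards above: within the AMD WMMA case the MMA-F16 path is no longer limited to RDNA4, and the compile-time skip check now also admits DKQ == 576, the MLA head size paired with DV == 512 in the existing kernel instantiations. A minimal sketch of that skip check as a host-side predicate, for illustration only (the helper name is hypothetical; in the kernel the unsupported cases simply hit NO_DEVICE_CODE and return):

// Hedged sketch: mirrors the device-side condition in flash_attn_ext_f16 above.
static bool amd_wmma_config_supported(int ncols1, int ncols2, int DKQ) {
    const int ncols = ncols1*ncols2;
    if (ncols < 16 || ncols > 32 || ncols2 == 1) {
        return false; // needs 16..32 columns per tile and a GQA ratio > 1
    }
    return DKQ <= 128 || DKQ == 576; // head sizes up to 128, plus the 576 MLA shape
}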
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn.cu
@@ -352,7 +352,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
-    if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
+    if (amd_wmma_available(cc) && (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) && gqa_opt_applies && (Q->ne[0] <= 128 || Q->ne[0] == 576) && Q->ne[0] != 40 && Q->ne[0] != 72) {
         if (can_use_vector_kernel) {
             if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                 if (Q->ne[1] == 1) {
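The host-side kernel selection now accepts RDNA3 as well as RDNA4 for this path, and its head-size filter matches the device-side check: up to 128, plus the 576 MLA case, with 40 and 72 still excluded. A hedged sketch of that filter as a standalone predicate (the helper name is hypothetical):

// Hedged sketch of the head-size filter from the condition above.
static bool amd_wmma_head_size_ok(int64_t ne00) {
    return (ne00 <= 128 || ne00 == 576) && ne00 != 40 && ne00 != 72;
}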
12 changes: 9 additions & 3 deletions ggml/src/ggml-cuda/mma.cuh
@@ -661,7 +661,7 @@ namespace ggml_cuda_mma {
 #endif // defined(TURING_MMA_AVAILABLE)
 
     static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
-#if defined(RDNA4)
+#if defined(RDNA4) || defined(RDNA3)
         const int row = t.get_i(0);
         const int left_right = t.get_j(0) / 4;
         const int up_down = row / 8;
@@ -670,7 +670,7 @@ namespace ggml_cuda_mma {
 #else
         GGML_UNUSED_VARS(t);
         NO_DEVICE_CODE;
-#endif // defined(RDNA4)
+#endif // defined(RDNA4) || defined(RDNA3)
     }
 
     template <int I, int J, typename T, data_layout dl>
@@ -919,10 +919,16 @@ namespace ggml_cuda_mma {
         const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
         const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
         acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#elif defined(RDNA3)
+        using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
+        halfx16_t& acc_frag = reinterpret_cast<halfx16_t&>(D.x[0]);
+        const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
+        const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a_frag, b_frag, acc_frag, false);
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // defined(RDNA4)
+#endif // defined(RDNA4) || defined(RDNA3)
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
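The new #elif branch above targets the RDNA3 (gfx11) wave32 WMMA builtin, which differs from the gfx12 variant in two ways: the A, B and accumulator fragments are 16 halves per lane instead of 8, and a trailing opsel flag selects whether results land in the low or high half of each accumulator register pair (the code passes false, i.e. the low halves, which map to the even vector elements). Below is a minimal standalone sketch of that builtin, assuming the fragment layout from AMD's published RDNA3 WMMA example; the kernel name, target-guard list and launch shape are illustrative and not part of this PR:

#include <hip/hip_runtime.h>

typedef _Float16 halfx16_t __attribute__((ext_vector_type(16)));

// Hedged sketch: one 16x16x16 half-precision WMMA on RDNA3, wave32.
// Assumes blockDim.x == 32 and row-major 16x16 matrices a, b, d.
__global__ void wmma_f16_16x16x16_rdna3(const _Float16 * a, const _Float16 * b, _Float16 * d) {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
    const int lIdx = threadIdx.x; // lane id within the wave32
    const int lane = lIdx % 16;   // the 16x16 tile is replicated across both half-waves

    halfx16_t a_frag, b_frag;
    halfx16_t c_frag = {}; // zero-initialized accumulator

    for (int ele = 0; ele < 16; ++ele) {
        a_frag[ele] = a[16*lane + ele]; // one row of A per lane
        b_frag[ele] = b[16*ele + lane]; // one column of B per lane
    }

    // opsel = false: results are packed into the even elements of the accumulator.
    c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a_frag, b_frag, c_frag, false);

    for (int ele = 0; ele < 8; ++ele) {
        const int row = 2*ele + lIdx/16; // the two half-waves write alternating rows
        d[16*row + lane] = c_frag[2*ele];
    }
#else
    (void) a; (void) b; (void) d; // other targets: no-op in this sketch
#endif
}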