From 04adc0c8b1f0ff918baa51a91c450a45b762405f Mon Sep 17 00:00:00 2001 From: Linus Gubenis <7453396+linus-amg@users.noreply.github.com> Date: Fri, 23 Jan 2026 23:22:56 -0600 Subject: [PATCH] HIP: Enable MMA flash attention for RDNA3 with head size 576 This enables MMA-based flash attention on RDNA3 GPUs (gfx1100/1101/1102) for models with head size 576, such as GLM-4.7-Flash and other MLA (Multi-head Latent Attention) models. Previously, flash attention with head size 576 only worked on CUDA (via PR #18953) and RDNA4. RDNA3 users had to disable flash attention, resulting in ~3x slower inference. Changes: - fattn.cu: Route RDNA3 + head size 576 to MMA kernel (was RDNA4-only) - fattn-mma-f16.cuh: Enable AMD WMMA for all RDNA3/RDNA4, allow DKQ==576 - mma.cuh: Add RDNA3 to make_identity_mat(), add f16->f16 WMMA intrinsic Tested on AMD RX 7900 XTX (gfx1100) with GLM-4.7-Flash-REAP-23B: - FA off: ~77 t/s - FA on (before, broken): ~27 t/s - FA on (after fix): ~83 t/s --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 14 +++++++------- ggml/src/ggml-cuda/fattn.cu | 2 +- ggml/src/ggml-cuda/mma.cuh | 12 +++++++++--- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 203569e3459..a4fa043278a 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -428,7 +428,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); @@ -881,7 +881,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, 
KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } #if defined(TURING_MMA_AVAILABLE) @@ -944,7 +944,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jt, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr int ncols = ncols1 * ncols2; @@ -1454,7 +1454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -1481,7 +1481,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1503,7 +1503,7 @@ static __global__ void flash_attn_ext_f16( #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING #if defined(AMD_WMMA_AVAILABLE) 
- if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) { + if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || (DKQ > 128 && DKQ != 576) || ncols2 == 1) { NO_DEVICE_CODE; return; } @@ -1622,7 +1622,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)) } template diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 87f07a2f938..dbb66e48cc2 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -352,7 +352,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_WMMA_F16; } - if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) { + if (amd_wmma_available(cc) && (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) && gqa_opt_applies && (Q->ne[0] <= 128 || Q->ne[0] == 576) && Q->ne[0] != 40 && Q->ne[0] != 72) { if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (Q->ne[1] == 1) { diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 42085d10027..45c44722bd8 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -661,7 +661,7 @@ namespace ggml_cuda_mma { #endif // defined(TURING_MMA_AVAILABLE) static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { -#if defined(RDNA4) +#if defined(RDNA4) || defined(RDNA3) const int row = t.get_i(0); const int left_right = t.get_j(0) / 4; const int up_down = row / 8; @@ -670,7 +670,7 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(t); NO_DEVICE_CODE; -#endif // defined(RDNA4) +#endif // 
defined(RDNA4) || defined(RDNA3) } template @@ -919,10 +919,16 @@ namespace ggml_cuda_mma { const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]); const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]); acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); +#elif defined(RDNA3) + using halfx16_t = __attribute__((ext_vector_type(16))) _Float16; + halfx16_t& acc_frag = reinterpret_cast<halfx16_t&>(D.x[0]); + const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]); + const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a_frag, b_frag, acc_frag, false); #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // defined(RDNA4) +#endif // defined(RDNA4) || defined(RDNA3) #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE;