From d79105aa61e88d02f3376e9825c889bdee965e85 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 28 Sep 2025 10:15:03 -0400 Subject: [PATCH 1/4] Fix load_fragment --- .../include/gpu_iface/backend/hip/mma_hip.h | 99 +++++++++---------- .../tests/hip/test_mfma_fp32_16x16x16fp16.cpp | 32 +++--- 2 files changed, 61 insertions(+), 70 deletions(-) diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index e6c774d1c4..7429e40966 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -37,34 +37,32 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t* R) { uint32_t lane_in_group = lane_id % 4; // === ROUND 1: Exchange with neighbor (XOR with 1) === - // T0↔T1, T2↔T3 partial exchange - uint32_t reg_idx = (lane_in_group >> 1) & 0x1; - uint32_t exchanged_val = __shfl_xor(R[reg_idx], 0x1); + // T0 <-> T1, T2 <-> T3 partial exchange + uint32_t regid = (lane_in_group >> 1) & 0x1; + uint32_t exchanged_val = __shfl_xor(R[regid], 0x1); uint32_t shift = (lane_in_group & 1) * 16; - uint32_t keep_mask = 0xFFFF0000 >> shift; - int right_shift_amount = 16 * (1 - (lane_in_group & 1)); - int left_shift_amount = 16 * (lane_in_group & 1); - R[reg_idx] = - (R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); + uint32_t keep_mask = 0x0000FFFF << shift; + int left_shift_amount = 16 * (1 - (lane_in_group & 1)); + int right_shift_amount = 16 * (lane_in_group & 1); + R[regid] = (R[regid] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); // === ROUND 2: Exchange with one hop (XOR with 2) === - // T0↔T2, T1↔T3 exchange R[0] and R[1] + // T0 <-> T2, T1 <-> T3 exchange R[0] and R[1] // Swap entire registers based on thread position - uint32_t is_top = 1 - reg_idx; + uint32_t is_top = 1 - regid; uint32_t temp0 = __shfl_xor(R[0], 0x2); uint32_t temp1 = __shfl_xor(R[1], 0x2); // Compute both possibilities and select - R[0] = R[0] * is_top + temp1 * reg_idx; - R[1] = temp0 * is_top + R[1] * reg_idx; + R[0] = R[0] * is_top + temp1 * regid; + R[1] = temp0 * is_top + R[1] * regid; // === ROUND 3: Exchange with neighbor again (XOR with 1) === - // T0↔T1, T2↔T3 exchange remaining parts + // T0 <-> T1, T2 <-> T3 exchange remaining parts - reg_idx = 1 - reg_idx; - exchanged_val = __shfl_xor(R[reg_idx], 0x1); - R[reg_idx] = - (R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); + regid = 1 - regid; + exchanged_val = __shfl_xor(R[regid], 0x1); + R[regid] = (R[regid] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); } // Single unified load function for all fragment types @@ -72,25 +70,8 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t* R) { /// @param smem_ptr [in] pointer to the shared memory to load the fragment from template __device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) { - const uint16_t* v0 = reinterpret_cast(smem_ptr) + 0; - const uint16_t* v1 = reinterpret_cast(++smem_ptr); - const uint16_t* v2 = reinterpret_cast(++smem_ptr); - const uint16_t* v3 = reinterpret_cast(++smem_ptr); - - R[0] = (static_cast(*v0) << 16) | static_cast(*v1); - R[1] = (static_cast(*v2) << 16) | static_cast(*v3); -} - -template -__device__ __forceinline__ void load_fragment_transpose(uint32_t* R, const T* smem_ptr, - uint32_t stride) { - const uint16_t* v0 = reinterpret_cast(smem_ptr) + 0; - const uint16_t* v1 = reinterpret_cast(smem_ptr + 1 * stride); - const uint16_t* v2 = reinterpret_cast(smem_ptr + 2 * stride); - const uint16_t* v3 = reinterpret_cast(smem_ptr + 3 * stride); - - R[0] = (static_cast(*v0) << 16) | static_cast(*v1); - R[1] = (static_cast(*v2) << 16) | static_cast(*v3); + R[0] = reinterpret_cast(smem_ptr)[0]; + R[1] = reinterpret_cast(smem_ptr)[1]; } // MMA operation for FP16 inputs with FP32 accumulator @@ -121,36 +102,46 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u C[3] = C_fp32[3]; } -/// Loads a fragment from LDS to two 32bit registers and then transposes +/// @brief Loads a fragment from LDS to two 32bit registers and then transposes /// the registers for a group of four consecuitive threads. +/// +/// transposes the values in four adjacent threads. The function does the +/// following layout transformation: +/// Original data in registers for Threads 0-3 after fragment load +/// T0 : a b c d +/// T1 : e f g h +/// T2 : i j k l +/// T3 : m n o p +/// +/// After transposition: +/// T0 : a e i m +/// T1 : b f j n +/// T2 : c g k o +/// T3 : d h l p template __device__ __forceinline__ void load_fragment_4x4_half_registers(uint32_t* R, const T* smem_ptr) { - static_assert(std::is_same_v(), "Only half type is supported"); - // Each thread loads 4 __half values in two 32b registers. + static_assert(std::is_same_v, "Only half type is supported"); load_fragment(R, smem_ptr); - // transposes the values in four adjacent threads. The function does the - // following layout transformation: - // Original data in registers for Threads 0-3 after fragment load - // T0 : a b c d - // T1 : e f g h - // T2 : i j k l - // T3 : m n o p - // - // After transposition: - // T0 : a e i m - // T1 : b f j n - // T2 : c g k o - // T3 : d h l p - transpose_4x4_half_registers(R); } +// TODO: Verify correct matrix multiplication order for rowsum on CDNA3 +// Current assumption: s_frag × ones_vector = row_sums +// Need to validate: +// 1. How compute_qk stores Q×K^T result in s_frag for CDNA3 +// 2. Whether K is pre-transposed or transposed during fragment loading +// 3. If we need s_frag × M1 or M1 × s_frag for correct row sums +// +// Test with known input matrices to verify: +// - s_frag layout matches expected Q×K^T result +// - rowsum produces correct per-row sums template __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s_frag) { static_assert(sizeof(DType) == 2, "DType must be 16-bit type"); + transpose_4x4_half_registers(reinterpret_cast(s_frag)); f16x4 a = reinterpret_cast(s_frag)[0]; f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)}; - f32x4 c = {0.f, 0.f, 0.f, 0.f}; + f32x4 c = {d[0], d[1], d[2], d[3]}; f32x4 out = __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, c, 0, 0, 0); d[0] = out.x; d[1] = out.y; diff --git a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp index 24eb5219af..4128c7f6cb 100644 --- a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp +++ b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp @@ -24,6 +24,18 @@ } \ } +namespace { + +__device__ void print_register(uint32_t* R) { + auto values = reinterpret_cast<__half*>(R); + printf("[%f %f %f %f]\n", __half2float(values[0]), __half2float(values[1]), + __half2float(values[2]), __half2float(values[3])); +} + +__device__ void print_register(float* R) { printf("[%f %f %f %f]\n", R[0], R[1], R[3], R[4]); } + +} // namespace + // Dimensions for our test matrices constexpr int M = 16; constexpr int N = 16; @@ -41,7 +53,6 @@ void gemm_reference(const __half* A, const __half* B, float* C, int M, int N, in for (int j = 0; j < N; ++j) { float acc = 0.0f; for (int k = 0; k < K; ++k) { - // Use __half_as_float to properly convert __half to float acc += __half2float(A[i * K + k]) * __half2float(B[k * N + j]); } C[i * N + j] = acc; @@ -54,26 +65,15 @@ __global__ void test_mfma_kernel(const __half* A, const __half* B, float* C) { uint32_t b_reg[2]; float c_reg[4] = {0.0f, 0.0f, 0.0f, 0.0f}; - // A Matrix is read row wise. Threads T0...T15 read Col 0...3 of Row 0...15 - // Threads T16...T31 read Col 4...7 of Row 0...15 - // Threads T32...T47 read Col 8...11 of Row 0...15 - // Threads T48...T63 read Col 12...15 of Row 0...15 - - // B Matrix is read column wise. Threads T0...T15 read Row 0...3 of Col - // 0...15 (Each thread reads 1 column per 4 rows) Threads T16...T31 read - // Row 4...7 of Col 0...15 Threads T32...T47 read Row 8...11 of Col 0...15 - // Threads T48...T63 read Row 12...15 of Col 0...15 - int a_idx = (threadIdx.x / 16) * 4 + threadIdx.x % 16 * LDA; - int b_idx = (threadIdx.x / 16) * LDB * 4 + threadIdx.x % 16; + int a_idx = (threadIdx.x % 16) * LDA + (threadIdx.x / 16) * 4; + int b_idx = ((threadIdx.x % 4) + 4 * (threadIdx.x / 16)) * LDB + ((threadIdx.x % 16) / 4) * 4; flashinfer::gpu_iface::mma::load_fragment<__half>(a_reg, &A[a_idx]); - flashinfer::gpu_iface::mma::load_fragment_transpose<__half>(b_reg, &B[b_idx], LDB); - + flashinfer::gpu_iface::mma::load_fragment_transpose_4x4_half_registers<__half>(b_reg, &B[b_idx]); flashinfer::gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32<__half>(c_reg, a_reg, b_reg); for (int i = 0; i < 4; ++i) { - const int d_idx = threadIdx.x % 16 + i * LDC + (threadIdx.x / 16) * 4 * LDC; - + int d_idx = ((threadIdx.x / 16) * 4 + i) * LDC + (threadIdx.x % 16); C[d_idx] = c_reg[i]; } } From 06cde056f439bb6d2eda412f8d0a7f7c8bbb4abf Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 28 Sep 2025 10:36:06 -0400 Subject: [PATCH 2/4] Fix Array OOO access in debug function --- libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp index 4128c7f6cb..d8897ee124 100644 --- a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp +++ b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp @@ -32,7 +32,7 @@ __device__ void print_register(uint32_t* R) { __half2float(values[2]), __half2float(values[3])); } -__device__ void print_register(float* R) { printf("[%f %f %f %f]\n", R[0], R[1], R[3], R[4]); } +__device__ void print_register(float* R) { printf("[%f %f %f %f]\n", R[0], R[1], R[2], R[3]); } } // namespace From 243adce80b9e64d7962ad1e3bbfedbe12804609b Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 28 Sep 2025 14:53:28 -0400 Subject: [PATCH 3/4] Fixes based on Copilot review. --- .../include/gpu_iface/backend/hip/mma_hip.h | 32 ++++++++++--------- libflashinfer/include/gpu_iface/mma_ops.hpp | 10 ++---- .../tests/hip/test_mfma_fp32_16x16x16fp16.cpp | 14 +------- 3 files changed, 20 insertions(+), 36 deletions(-) diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 7429e40966..a6bc6b3fb0 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -12,16 +12,6 @@ using f16 = _Float16; using f16x4 = f16 __attribute__((ext_vector_type(4))); using f32x4 = float __attribute__((ext_vector_type(4))); -template -__device__ __forceinline__ f32x4 mfma_fp32_16x16x16fp16(f32x4 C, const f16x4 A, const f16x4 B) { - if constexpr (std::is_same_v) { - return __builtin_amdgcn_mfma_f32_16x16x16f16(A, B, C, 0, 0, 0); - } else if constexpr (std::is_same_v) { - return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(A, B, C, 0, 0, 0); - } - return C; -} - } // namespace namespace flashinfer { @@ -78,6 +68,7 @@ __device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) { template __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A, uint32_t* B) { +#if defined(__HIP_DEVICE_COMPILE__) && (__gfx90a__ || __gfx908__ || __gfx942__) // Ensure T is either __half or __hip_bfloat16 static_assert(std::is_same_v || std::is_same_v, "T must be __half or __hip_bfloat16"); @@ -95,11 +86,22 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u f32x4 C_fp32 = reinterpret_cast(C)[0]; // Perform MMA operation directly with fragments - C_fp32 = mfma_fp32_16x16x16fp16(C_fp32, A_fp16, B_fp16); - C[0] = C_fp32[0]; - C[1] = C_fp32[1]; - C[2] = C_fp32[2]; - C[3] = C_fp32[3]; + + if constexpr (std::is_same_v) { + C_fp32 = __builtin_amdgcn_mfma_f32_16x16x16f16(A_fp16, B_fp16, C_fp32, 0, 0, 0); + } else if constexpr (std::is_same_v) { + C_fp32 = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(A_fp16, B_fp16, C_fp32, 0, 0, 0); + } + + reinterpret_cast(C)[0] = C_fp32; + + // C[0] = C_fp32[0]; + // C[1] = C_fp32[1]; + // C[2] = C_fp32[2]; + // C[3] = C_fp32[3]; +#elif defined(__HIP_DEVICE_COMPILE__) +#error "Unsupported GFX platform for MFMA ops. +#endif } /// @brief Loads a fragment from LDS to two 32bit registers and then transposes diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index 3cbff01067..6df297a96e 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -33,19 +33,13 @@ __device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) { mma_detail::load_fragment(R, smem_ptr); } -template -__device__ __forceinline__ void load_fragment_transpose(uint32_t* R, const T* smem_ptr, - uint32_t stride) { - mma_detail::load_fragment_transpose(R, smem_ptr, stride); -} - #if defined(PLATFORM_HIP_DEVICE) && defined(__gfx942__) template __device__ __forceinline__ void load_fragment_transpose_4x4_half_registers(uint32_t* R, const T* smem_ptr) { - static_assert(std::is_same::value, + static_assert(std::is_same::value, "Only __half is supported for the 4x4 register transpose"); - mma_detail::load_fragment_4x4_half_registers(R, smem_ptr); + mma_detail::load_fragment_4x4_half_registers<__half>(R, smem_ptr); } #endif diff --git a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp index d8897ee124..56c8f9391d 100644 --- a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp +++ b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp @@ -24,18 +24,6 @@ } \ } -namespace { - -__device__ void print_register(uint32_t* R) { - auto values = reinterpret_cast<__half*>(R); - printf("[%f %f %f %f]\n", __half2float(values[0]), __half2float(values[1]), - __half2float(values[2]), __half2float(values[3])); -} - -__device__ void print_register(float* R) { printf("[%f %f %f %f]\n", R[0], R[1], R[2], R[3]); } - -} // namespace - // Dimensions for our test matrices constexpr int M = 16; constexpr int N = 16; @@ -69,7 +57,7 @@ __global__ void test_mfma_kernel(const __half* A, const __half* B, float* C) { int b_idx = ((threadIdx.x % 4) + 4 * (threadIdx.x / 16)) * LDB + ((threadIdx.x % 16) / 4) * 4; flashinfer::gpu_iface::mma::load_fragment<__half>(a_reg, &A[a_idx]); - flashinfer::gpu_iface::mma::load_fragment_transpose_4x4_half_registers<__half>(b_reg, &B[b_idx]); + flashinfer::gpu_iface::mma_impl::hip::load_fragment_4x4_half_registers<__half>(b_reg, &B[b_idx]); flashinfer::gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32<__half>(c_reg, a_reg, b_reg); for (int i = 0; i < 4; ++i) { From 51bda7b34d086690c635d97fd5ee807192906045 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 28 Sep 2025 14:56:26 -0400 Subject: [PATCH 4/4] More Copilot review based fixes --- libflashinfer/include/gpu_iface/backend/hip/mma_hip.h | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index a6bc6b3fb0..5a1cffaf2f 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -94,13 +94,8 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u } reinterpret_cast(C)[0] = C_fp32; - - // C[0] = C_fp32[0]; - // C[1] = C_fp32[1]; - // C[2] = C_fp32[2]; - // C[3] = C_fp32[3]; #elif defined(__HIP_DEVICE_COMPILE__) -#error "Unsupported GFX platform for MFMA ops. +#error "Unsupported GFX platform for MFMA ops." #endif }