Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 57 additions & 69 deletions libflashinfer/include/gpu_iface/backend/hip/mma_hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@ using f16 = _Float16;
using f16x4 = f16 __attribute__((ext_vector_type(4)));
using f32x4 = float __attribute__((ext_vector_type(4)));

template <typename T>
__device__ __forceinline__ f32x4 mfma_fp32_16x16x16fp16(f32x4 C, const f16x4 A, const f16x4 B) {
if constexpr (std::is_same_v<T, __half>) {
return __builtin_amdgcn_mfma_f32_16x16x16f16(A, B, C, 0, 0, 0);
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(A, B, C, 0, 0, 0);
}
return C;
}

} // namespace

namespace flashinfer {
Expand All @@ -37,66 +27,48 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t* R) {
uint32_t lane_in_group = lane_id % 4;

// === ROUND 1: Exchange with neighbor (XOR with 1) ===
// T0T1, T2T3 partial exchange
uint32_t reg_idx = (lane_in_group >> 1) & 0x1;
uint32_t exchanged_val = __shfl_xor(R[reg_idx], 0x1);
// T0 <-> T1, T2 <-> T3 partial exchange
uint32_t regid = (lane_in_group >> 1) & 0x1;
uint32_t exchanged_val = __shfl_xor(R[regid], 0x1);
uint32_t shift = (lane_in_group & 1) * 16;
uint32_t keep_mask = 0xFFFF0000 >> shift;
int right_shift_amount = 16 * (1 - (lane_in_group & 1));
int left_shift_amount = 16 * (lane_in_group & 1);
R[reg_idx] =
(R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount);
uint32_t keep_mask = 0x0000FFFF << shift;
int left_shift_amount = 16 * (1 - (lane_in_group & 1));
int right_shift_amount = 16 * (lane_in_group & 1);
R[regid] = (R[regid] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount);

// === ROUND 2: Exchange with one hop (XOR with 2) ===
// T0T2, T1T3 exchange R[0] and R[1]
// T0 <-> T2, T1 <-> T3 exchange R[0] and R[1]
// Swap entire registers based on thread position
uint32_t is_top = 1 - reg_idx;
uint32_t is_top = 1 - regid;
uint32_t temp0 = __shfl_xor(R[0], 0x2);
uint32_t temp1 = __shfl_xor(R[1], 0x2);

// Compute both possibilities and select
R[0] = R[0] * is_top + temp1 * reg_idx;
R[1] = temp0 * is_top + R[1] * reg_idx;
R[0] = R[0] * is_top + temp1 * regid;
R[1] = temp0 * is_top + R[1] * regid;

// === ROUND 3: Exchange with neighbor again (XOR with 1) ===
// T0T1, T2T3 exchange remaining parts
// T0 <-> T1, T2 <-> T3 exchange remaining parts

reg_idx = 1 - reg_idx;
exchanged_val = __shfl_xor(R[reg_idx], 0x1);
R[reg_idx] =
(R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount);
regid = 1 - regid;
exchanged_val = __shfl_xor(R[regid], 0x1);
R[regid] = (R[regid] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount);
}

// Single unified load function for all fragment types
/// @param R [in] pointer to the register file to load the fragment into
/// @param smem_ptr [in] pointer to the shared memory to load the fragment from
template <typename T>
__device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) {
const uint16_t* v0 = reinterpret_cast<const uint16_t*>(smem_ptr) + 0;
const uint16_t* v1 = reinterpret_cast<const uint16_t*>(++smem_ptr);
const uint16_t* v2 = reinterpret_cast<const uint16_t*>(++smem_ptr);
const uint16_t* v3 = reinterpret_cast<const uint16_t*>(++smem_ptr);

R[0] = (static_cast<const uint32_t>(*v0) << 16) | static_cast<const uint32_t>(*v1);
R[1] = (static_cast<const uint32_t>(*v2) << 16) | static_cast<const uint32_t>(*v3);
}

template <typename T>
__device__ __forceinline__ void load_fragment_transpose(uint32_t* R, const T* smem_ptr,
uint32_t stride) {
const uint16_t* v0 = reinterpret_cast<const uint16_t*>(smem_ptr) + 0;
const uint16_t* v1 = reinterpret_cast<const uint16_t*>(smem_ptr + 1 * stride);
const uint16_t* v2 = reinterpret_cast<const uint16_t*>(smem_ptr + 2 * stride);
const uint16_t* v3 = reinterpret_cast<const uint16_t*>(smem_ptr + 3 * stride);

R[0] = (static_cast<const uint32_t>(*v0) << 16) | static_cast<const uint32_t>(*v1);
R[1] = (static_cast<const uint32_t>(*v2) << 16) | static_cast<const uint32_t>(*v3);
R[0] = reinterpret_cast<const uint32_t*>(smem_ptr)[0];
R[1] = reinterpret_cast<const uint32_t*>(smem_ptr)[1];
}

// MMA operation for FP16 inputs with FP32 accumulator
template <typename T, mma::MMAMode mma_mode = mma::MMAMode::kInplaceUpdate>
__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A,
uint32_t* B) {
#if defined(__HIP_DEVICE_COMPILE__) && (__gfx90a__ || __gfx908__ || __gfx942__)
// Ensure T is either __half or __hip_bfloat16
static_assert(std::is_same_v<T, __half> || std::is_same_v<T, __hip_bfloat16>,
"T must be __half or __hip_bfloat16");
Expand All @@ -114,43 +86,59 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
f32x4 C_fp32 = reinterpret_cast<f32x4*>(C)[0];

// Perform MMA operation directly with fragments
C_fp32 = mfma_fp32_16x16x16fp16<T>(C_fp32, A_fp16, B_fp16);
C[0] = C_fp32[0];
C[1] = C_fp32[1];
C[2] = C_fp32[2];
C[3] = C_fp32[3];

if constexpr (std::is_same_v<T, __half>) {
C_fp32 = __builtin_amdgcn_mfma_f32_16x16x16f16(A_fp16, B_fp16, C_fp32, 0, 0, 0);
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
C_fp32 = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(A_fp16, B_fp16, C_fp32, 0, 0, 0);
}

reinterpret_cast<f32x4*>(C)[0] = C_fp32;
#elif defined(__HIP_DEVICE_COMPILE__)
#error "Unsupported GFX platform for MFMA ops."
#endif
}

/// Loads a fragment from LDS to two 32bit registers and then transposes
/// @brief Loads a fragment from LDS to two 32bit registers and then transposes
/// the registers for a group of four consecutive threads.
///
/// transposes the values in four adjacent threads. The function does the
/// following layout transformation:
/// Original data in registers for Threads 0-3 after fragment load
/// T0 : a b c d
/// T1 : e f g h
/// T2 : i j k l
/// T3 : m n o p
///
/// After transposition:
/// T0 : a e i m
/// T1 : b f j n
/// T2 : c g k o
/// T3 : d h l p
template <typename T>
__device__ __forceinline__ void load_fragment_4x4_half_registers(uint32_t* R, const T* smem_ptr) {
static_assert(std::is_same_v<T, __half>(), "Only half type is supported");
// Each thread loads 4 __half values in two 32b registers.
static_assert(std::is_same_v<T, __half>, "Only half type is supported");
Copy link

Copilot AI Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The static_assert uses std::is_same_v which requires C++17, but the codebase appears to use C++14/C++11 style elsewhere (e.g., std::is_same<T, __half>::value in the other file). This should be consistent with the project's C++ standard.

Suggested change
static_assert(std::is_same_v<T, __half>, "Only half type is supported");
static_assert(std::is_same<T, __half>::value, "Only half type is supported");

Copilot uses AI. Check for mistakes.
load_fragment(R, smem_ptr);
// transposes the values in four adjacent threads. The function does the
// following layout transformation:
// Original data in registers for Threads 0-3 after fragment load
// T0 : a b c d
// T1 : e f g h
// T2 : i j k l
// T3 : m n o p
//
// After transposition:
// T0 : a e i m
// T1 : b f j n
// T2 : c g k o
// T3 : d h l p

transpose_4x4_half_registers(R);
}

// TODO: Verify correct matrix multiplication order for rowsum on CDNA3
// Current assumption: s_frag × ones_vector = row_sums
// Need to validate:
// 1. How compute_qk stores Q×K^T result in s_frag for CDNA3
// 2. Whether K is pre-transposed or transposed during fragment loading
// 3. If we need s_frag × M1 or M1 × s_frag for correct row sums
//
// Test with known input matrices to verify:
// - s_frag layout matches expected Q×K^T result
// - rowsum produces correct per-row sums
template <typename DType>
__device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s_frag) {
static_assert(sizeof(DType) == 2, "DType must be 16-bit type");
transpose_4x4_half_registers(reinterpret_cast<uint32_t*>(s_frag));
f16x4 a = reinterpret_cast<const f16x4*>(s_frag)[0];
f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)};
f32x4 c = {0.f, 0.f, 0.f, 0.f};
f32x4 c = {d[0], d[1], d[2], d[3]};
f32x4 out = __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, c, 0, 0, 0);
d[0] = out.x;
d[1] = out.y;
Expand Down
10 changes: 2 additions & 8 deletions libflashinfer/include/gpu_iface/mma_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,13 @@ __device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) {
mma_detail::load_fragment<T>(R, smem_ptr);
}

template <typename T>
__device__ __forceinline__ void load_fragment_transpose(uint32_t* R, const T* smem_ptr,
uint32_t stride) {
mma_detail::load_fragment_transpose<T>(R, smem_ptr, stride);
}

#if defined(PLATFORM_HIP_DEVICE) && defined(__gfx942__)
template <typename T>
__device__ __forceinline__ void load_fragment_transpose_4x4_half_registers(uint32_t* R,
const T* smem_ptr) {
static_assert(std::is_same<T, int>::value,
static_assert(std::is_same<T, __half>::value,
"Only __half is supported for the 4x4 register transpose");
Comment on lines 37 to 41
Copy link

Copilot AI Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The static_assert checks for type T but the function call always uses __half. This creates a mismatch - either the template parameter T should be used in the function call, or the static_assert should be removed since T is not used.

Copilot uses AI. Check for mistakes.
mma_detail::load_fragment_4x4_half_registers<half>(R, smem_ptr);
mma_detail::load_fragment_4x4_half_registers<__half>(R, smem_ptr);
}
#endif

Expand Down
20 changes: 4 additions & 16 deletions libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ void gemm_reference(const __half* A, const __half* B, float* C, int M, int N, in
for (int j = 0; j < N; ++j) {
float acc = 0.0f;
for (int k = 0; k < K; ++k) {
// Use __half_as_float to properly convert __half to float
acc += __half2float(A[i * K + k]) * __half2float(B[k * N + j]);
}
C[i * N + j] = acc;
Expand All @@ -54,26 +53,15 @@ __global__ void test_mfma_kernel(const __half* A, const __half* B, float* C) {
uint32_t b_reg[2];
float c_reg[4] = {0.0f, 0.0f, 0.0f, 0.0f};

// A Matrix is read row wise. Threads T0...T15 read Col 0...3 of Row 0...15
// Threads T16...T31 read Col 4...7 of Row 0...15
// Threads T32...T47 read Col 8...11 of Row 0...15
// Threads T48...T63 read Col 12...15 of Row 0...15

// B Matrix is read column wise. Threads T0...T15 read Row 0...3 of Col
// 0...15 (Each thread reads 1 column per 4 rows) Threads T16...T31 read
// Row 4...7 of Col 0...15 Threads T32...T47 read Row 8...11 of Col 0...15
// Threads T48...T63 read Row 12...15 of Col 0...15
int a_idx = (threadIdx.x / 16) * 4 + threadIdx.x % 16 * LDA;
int b_idx = (threadIdx.x / 16) * LDB * 4 + threadIdx.x % 16;
int a_idx = (threadIdx.x % 16) * LDA + (threadIdx.x / 16) * 4;
int b_idx = ((threadIdx.x % 4) + 4 * (threadIdx.x / 16)) * LDB + ((threadIdx.x % 16) / 4) * 4;

flashinfer::gpu_iface::mma::load_fragment<__half>(a_reg, &A[a_idx]);
flashinfer::gpu_iface::mma::load_fragment_transpose<__half>(b_reg, &B[b_idx], LDB);

flashinfer::gpu_iface::mma_impl::hip::load_fragment_4x4_half_registers<__half>(b_reg, &B[b_idx]);
flashinfer::gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32<__half>(c_reg, a_reg, b_reg);

for (int i = 0; i < 4; ++i) {
const int d_idx = threadIdx.x % 16 + i * LDC + (threadIdx.x / 16) * 4 * LDC;

int d_idx = ((threadIdx.x / 16) * 4 + i) * LDC + (threadIdx.x % 16);
C[d_idx] = c_reg[i];
}
}
Expand Down
Loading